rtime.felk.cvut.cz Git - hubacji1/iamcar.git/blob - decision_control/rrtplanner.cc
Add RRT*-Connect main loop
[hubacji1/iamcar.git] / decision_control / rrtplanner.cc
1 /*
2 This file is part of I am car.
3
4 I am car is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
8
9 I am car is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with I am car. If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #include <cstdlib>
19 #include <ctime>
20 #include "compile.h"
21 #include "nn.h"
22 #include "nv.h"
23 #include "sample.h"
24 #include "steer.h"
25 #include "rrtplanner.h"
26 #include "cost.h"
27
// Token-pasting helpers. CATI pastes its two arguments directly; CAT adds one
// level of indirection so that macro arguments (such as CO) are expanded
// before they are pasted.
28 #define CATI(a, b) a ## b
29 #define CAT(a, b) CATI(a, b)
// Cost functions used by Kuwata2008. KUWATA2008_CCOST is the identifier made
// by prefixing the expansion of CO with `c` (e.g. CO = foo yields cfoo);
// KUWATA2008_DCOST is CO itself. NOTE(review): presumably the `c` variant is
// a cumulative cost and CO a direct cost -- confirm against cost.h.
30 #define KUWATA2008_CCOST CAT(c, CO)
31 #define KUWATA2008_DCOST CO
32
33 LaValle1998::LaValle1998(RRTNode *init, RRTNode *goal):
34         RRTBase(init, goal),
35         nn(NN),
36         sample(SA),
37         steer(ST),
38         cost(CO)
39 {
40         srand(static_cast<unsigned>(time(0)));
41 }
42
43 bool LaValle1998::next()
44 {
45         RRTNode *rs;
46 #if GOALFIRST > 0
47         if (this->samples().size() == 0)
48                 rs = this->goal();
49         else
50                 rs = this->sample();
51 #else
52         rs = this->sample();
53 #endif
54         this->samples().push_back(rs);
55 #if NNVERSION>2
56         RRTNode *nn = this->nn(this->iy_, rs, this->cost, '0');
57 #elif NNVERSION>1
58         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
59 #else
60         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
61 #endif
62         RRTNode *pn = nn;
63         bool en_add = true;
64         for (auto ns: this->steer(nn, rs)) {
65                 if (!en_add) {
66                         delete ns;
67                 } else {
68                         this->nodes().push_back(ns);
69                         this->add_iy(ns);
70                         pn->add_child(ns, this->cost(pn, ns));
71                         if (this->collide(pn, ns)) {
72                                 pn->children().pop_back();
73                                 ns->remove_parent();
74                                 this->iy_[IYI(ns->y())].pop_back();
75                                 en_add = false;
76                         } else {
77                                 this->ocost(ns);
78                                 pn = ns;
79                                 if (this->goal_found(pn, this->cost)) {
80                                         this->tlog(this->findt());
81                                         en_add = false;
82                                 }
83                         }
84                 }
85         }
86         return this->goal_found();
87 }
88
89 Kuwata2008::Kuwata2008(RRTNode *init, RRTNode *goal):
90         RRTBase(init, goal),
91         nn(NN),
92         sample(SA),
93         steer(ST),
94         cost(KUWATA2008_DCOST)
95 {
96         srand(static_cast<unsigned>(time(0)));
97 }
98
99 bool Kuwata2008::next()
100 {
101         RRTNode *rs;
102         if (this->samples().size() == 0) {
103                 rs = this->goal();
104         } else {
105                 rs = this->sample();
106         }
107         this->samples().push_back(rs);
108         float heur = static_cast<float>(rand()) / static_cast<float>(RAND_MAX);
109         if (this->goal_found()) {
110                 if (heur < 0.7)
111                         this->cost = &KUWATA2008_CCOST;
112                 else
113                         this->cost = &KUWATA2008_DCOST;
114         } else {
115                 if (heur < 0.3)
116                         this->cost = &KUWATA2008_CCOST;
117                 else
118                         this->cost = &KUWATA2008_DCOST;
119         }
120 #if NNVERSION>2
121         RRTNode *nn = this->nn(this->iy_, rs, this->cost, '0');
122 #elif NNVERSION>1
123         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
124 #else
125         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
126 #endif
127         RRTNode *pn = nn;
128         std::vector<RRTNode *> newly_added;
129         bool en_add = true;
130         for (auto ns: this->steer(nn, rs)) {
131                 if (!en_add) {
132                         delete ns;
133                 } else {
134                         this->nodes().push_back(ns);
135                         this->add_iy(ns);
136                         pn->add_child(ns, KUWATA2008_DCOST(pn, ns));
137                         if (this->collide(pn, ns)) {
138                                 pn->children().pop_back();
139                                 ns->remove_parent();
140                                 this->iy_[IYI(ns->y())].pop_back();
141                                 en_add = false;
142                         } else {
143                                 this->ocost(ns);
144                                 pn = ns;
145                                 newly_added.push_back(pn);
146                                 if (this->goal_found(pn, &KUWATA2008_DCOST)) {
147                                         this->tlog(this->findt());
148                                         en_add = false;
149                                 }
150                         }
151                 }
152         }
153         if (this->samples().size() <= 1)
154                 return this->goal_found();
155         for (auto na: newly_added) {
156                 pn = na;
157                 en_add = true;
158                 for (auto ns: this->steer(na, this->goal())) {
159                         if (!en_add) {
160                                 delete ns;
161                         } else {
162                                 this->nodes().push_back(ns);
163                                 this->add_iy(ns);
164                                 pn->add_child(ns, KUWATA2008_DCOST(pn, ns));
165                                 if (this->collide(pn, ns)) {
166                                         pn->children().pop_back();
167                                         ns->remove_parent();
168                                         this->iy_[IYI(ns->y())].pop_back();
169                                         en_add = false;
170                                 } else {
171                                         this->ocost(ns);
172                                         pn = ns;
173                                         if (this->goal_found(pn,
174                                                         &KUWATA2008_DCOST)) {
175                                                 this->tlog(this->findt());
176                                                 en_add = false;
177                                         }
178                                 }
179                         }
180                 }
181         }
182         return this->goal_found();
183 }
184
185 Karaman2011::Karaman2011(RRTNode *init, RRTNode *goal):
186         RRTBase(init, goal),
187         nn(NN),
188         nv(NV),
189         sample(SA),
190         steer(ST),
191         cost(CO)
192 {
193         srand(static_cast<unsigned>(time(0)));
194 }
195
196 bool Karaman2011::next()
197 {
198         RRTNode *rs;
199 #if GOALFIRST > 0
200         if (this->samples().size() == 0)
201                 rs = this->goal();
202         else
203                 rs = this->sample();
204 #else
205         rs = this->sample();
206 #endif
207         this->samples().push_back(rs);
208 #if NNVERSION>2
209         RRTNode *nn = this->nn(this->iy_, rs, this->cost, '0');
210 #elif NNVERSION>1
211         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
212 #else
213         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
214 #endif
215         RRTNode *pn = nn;
216         std::vector<RRTNode *> nvs;
217         bool en_add = true;
218         for (auto ns: this->steer(nn, rs)) {
219                 if (!en_add) {
220                         delete ns;
221                 } else {
222 #if NVVERSION>2
223         nvs = this->nv(
224                         this->iy_,
225                         ns,
226                         this->cost,
227                         MIN(
228                                 GAMMA_RRTSTAR(
229                                         this->nodes().size()),
230                                 0.2),
231                         '0'); // TODO const
232 #elif NVVERSION>1
233                         nvs = this->nv(
234                                         this->iy_,
235                                         ns,
236                                         this->cost,
237                                         MIN(
238                                                 GAMMA_RRTSTAR(
239                                                         this->nodes().size()),
240                                                 0.2)); // TODO const
241 #else
242                         nvs = this->nv(
243                                         this->root(),
244                                         ns,
245                                         this->cost,
246                                         MIN(
247                                                 GAMMA_RRTSTAR(
248                                                         this->nodes().size()),
249                                                 0.2)); // TODO const
250 #endif
251                         this->nodes().push_back(ns);
252                         this->add_iy(ns);
253                         // connect
254                         if (!this->connect(pn, ns, nvs)) {
255                                 this->iy_[IYI(ns->y())].pop_back();
256                                 en_add = false;
257                         } else {
258                                 // rewire
259                                 this->rewire(nvs, ns);
260                                 pn = ns;
261                                 if (this->goal_found(pn, this->cost)) {
262                                         this->tlog(this->findt());
263                                         en_add = false;
264                                 }
265                         }
266                 }
267         }
268         return this->goal_found();
269 }
270
271 bool Karaman2011::connect(
272                 RRTNode *pn,
273                 RRTNode *ns,
274                 std::vector<RRTNode *> nvs)
275 {
276         RRTNode *op; // old parent
277         float od; // old direct cost
278         float oc; // old cumulative cost
279         bool connected = false;
280         pn->add_child(ns, this->cost(pn, ns));
281         if (this->collide(pn, ns)) {
282                 pn->children().pop_back();
283                 ns->remove_parent();
284         } else {
285                 this->ocost(ns);
286                 connected = true;
287         }
288         for (auto nv: nvs) {
289                 if (!connected || (nv->ccost() + this->cost(nv, ns) <
290                                 ns->ccost())) {
291                         op = ns->parent();
292                         od = ns->dcost();
293                         oc = ns->ccost();
294                         nv->add_child(ns, this->cost(nv, ns));
295                         if (this->collide(nv, ns)) {
296                                 nv->children().pop_back();
297                                 if (op)
298                                         ns->parent(op);
299                                 else
300                                         ns->remove_parent();
301                                 ns->dcost(od);
302                                 ns->ccost(oc);
303                         } else if (connected) {
304                                 op->children().pop_back();
305                         } else {
306                                 this->ocost(ns);
307                                 connected = true;
308                         }
309                 }
310         }
311         return connected;
312 }
313
314 bool Karaman2011::rewire(std::vector<RRTNode *> nvs, RRTNode *ns)
315 {
316         RRTNode *op; // old parent
317         float od; // old direct cost
318         float oc; // old cumulative cost
319         for (auto nv: nvs) {
320                 if (ns->ccost() + this->cost(ns, nv) < nv->ccost()) {
321                         op = nv->parent();
322                         od = nv->dcost();
323                         oc = nv->ccost();
324                         ns->add_child(nv, this->cost(ns, nv));
325                         if (this->collide(ns, nv)) {
326                                 ns->children().pop_back();
327                                 nv->parent(op);
328                                 nv->dcost(od);
329                                 nv->ccost(oc);
330                         } else {
331                                 op->rem_child(nv);
332                         }
333                 }
334         }
335         return true;
336 }
337
338 T1::T1(RRTNode *init, RRTNode *goal):
339         RRTBase(init, goal),
340         nn(NN),
341         nv(NV),
342         sample(SA),
343         steer(ST),
344         cost(CO)
345 {
346         srand(static_cast<unsigned>(time(0)));
347 }
348
349 bool T1::next()
350 {
351         RRTNode *rs;
352         if (this->samples().size() == 0)
353                 rs = this->goal();
354         else
355                 rs = this->sample();
356         this->samples().push_back(rs);
357 #if NNVERSION>2
358         RRTNode *nn = this->nn(this->iy_, rs, this->cost, '0');
359 #elif NNVERSION>1
360         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
361 #else
362         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
363 #endif
364         RRTNode *pn = nn;
365         std::vector<RRTNode *> nvs;
366         bool connected;
367         RRTNode *op; // old parent
368         float od; // old direct cost
369         float oc; // old cumulative cost
370         std::vector<RRTNode *> steered = this->steer(nn, rs);
371         // RRT* for first node
372         RRTNode *ns = steered[0];
373         {
374 #if NVVERSION>2
375         nvs = this->nv(
376                         this->iy_,
377                         ns,
378                         this->cost,
379                         MIN(
380                                 GAMMA_RRTSTAR(
381                                         this->nodes().size()),
382                                 0.2),
383                         '0'); // TODO const
384 #elif NVVERSION>1
385                 nvs = this->nv(this->iy_, ns, this->cost, MIN(
386                                         GAMMA_RRTSTAR(this->nodes().size()),
387                                         0.2)); // TODO const
388 #else
389                 nvs = this->nv(this->root(), ns, this->cost, MIN(
390                                         GAMMA_RRTSTAR(this->nodes().size()),
391                                         0.2)); // TODO const
392 #endif
393                 this->nodes().push_back(ns);
394                 this->add_iy(ns);
395                 connected = false;
396                 pn->add_child(ns, this->cost(pn, ns));
397                 if (this->collide(pn, ns)) {
398                         pn->children().pop_back();
399                 } else {
400                         connected = true;
401                 }
402                 // connect
403                 for (auto nv: nvs) {
404                         if (!connected || (nv->ccost() + this->cost(nv, ns) <
405                                         ns->ccost())) {
406                                 op = ns->parent();
407                                 od = ns->dcost();
408                                 oc = ns->ccost();
409                                 nv->add_child(ns, this->cost(nv, ns));
410                                 if (this->collide(nv, ns)) {
411                                         nv->children().pop_back();
412                                         ns->parent(op);
413                                         ns->dcost(od);
414                                         ns->ccost(oc);
415                                 } else if (connected) {
416                                         op->children().pop_back();
417                                 } else {
418                                         connected = true;
419                                 }
420                         }
421                 }
422                 if (!connected)
423                         return false;
424                 // rewire
425                 for (auto nv: nvs) {
426                         if (ns->ccost() + this->cost(ns, nv) < nv->ccost()) {
427                                 op = nv->parent();
428                                 od = nv->dcost();
429                                 oc = nv->ccost();
430                                 ns->add_child(nv, this->cost(ns, nv));
431                                 if (this->collide(ns, nv)) {
432                                         ns->children().pop_back();
433                                         nv->parent(op);
434                                         nv->dcost(od);
435                                         nv->ccost(oc);
436                                 } else {
437                                         op->rem_child(nv);
438                                 }
439                         }
440                 }
441                 pn = ns;
442                 if (this->goal_found(pn, this->cost)) {
443                         this->tlog(this->findt());
444                 }
445         }
446         unsigned int i = 0;
447         for (i = 1; i < steered.size(); i++) {
448                 ns = steered[i];
449                 this->nodes().push_back(ns);
450                 this->add_iy(ns);
451                 pn->add_child(ns, this->cost(pn, ns));
452                 if (this->collide(pn, ns)) {
453                         pn->children().pop_back();
454                         break;
455                 }
456                 pn = ns;
457                 if (this->goal_found(pn, this->cost)) {
458                         this->tlog(this->findt());
459                         break;
460                 }
461         }
462         return this->goal_found();
463 }
464
// One iteration of the T2 planner.
//
// Phase 1: insert the nodes steered from the nearest tree node towards the
// target via this->connect()/this->rewire(). The target is the goal on the
// first call when GOALFIRST > 0, otherwise a random sample. NOTE(review):
// connect/rewire are presumably the Karaman2011 versions -- T2's class
// declaration is not visible here.
//
// Phase 2 (skipped on the very first sample): from every node added in
// phase 1, extend straight towards the goal with plain RRT edges.
//
// In both phases:
// - steered nodes near the current tip (IS_NEAR) are skipped;
// - extension stops after more than 4 cusps, at a collision, or once the
//   goal is reached; goal_cost() then tries to re-parent the goal.
//
// Returns false if no nearest node exists, otherwise goal_found().
465 bool T2::next()
466 {
467         RRTNode *rs;
468 #if GOALFIRST > 0
469         if (this->samples().size() == 0)
470                 rs = this->goal();
471         else
472                 rs = this->sample();
473 #else
474         rs = this->sample();
475 #endif
476         this->samples().push_back(rs);
477 #if NNVERSION>2
478         RRTNode *nn = this->nn(this->iy_, rs, this->cost, '0');
479 #elif NNVERSION>1
480         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
481 #else
482         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
483 #endif
        // The nearest-neighbour query may come back empty.
484         if (!nn)
485                 return false;
486         RRTNode *pn = nn;
487         std::vector<RRTNode *> nvs;
488         std::vector<RRTNode *> newly_added;
489         bool en_add = true;
490         int cusps = 0;
491         for (auto ns: this->steer(nn, rs)) {
492                 if (!en_add) {
493                         delete ns;
494                 } else if (IS_NEAR(pn, ns)) {
                        // Too close to the current tip to be worth adding.
495                         delete ns;
496                 } else {
                        // Count a cusp when s() of the new node is zero or
                        // changes sign relative to the tip. NOTE(review):
                        // presumably s() is the signed driving direction.
497                         if (sgn(ns->s()) == 0 || sgn(pn->s()) != sgn(ns->s()))
498                                 cusps++;
                        // Exceeding the limit only stops *later* nodes; this
                        // node is still inserted below.
499                         if (cusps > 4)
500                                 en_add = false;
                        // Near vertices are looked up before `ns` itself is
                        // inserted.
501 #if NVVERSION>2
502         nvs = this->nv(
503                         this->iy_,
504                         ns,
505                         this->cost,
506                         MIN(
507                                 GAMMA_RRTSTAR(
508                                         this->nodes().size()),
509                                 0.2),
510                         '0'); // TODO const
511 #elif NVVERSION>1
512                         nvs = this->nv(
513                                         this->iy_,
514                                         ns,
515                                         this->cost,
516                                         MIN(
517                                                 GAMMA_RRTSTAR(
518                                                         this->nodes().size()),
519                                                 0.2)); // TODO const
520 #else
521                         nvs = this->nv(
522                                         this->root(),
523                                         ns,
524                                         this->cost,
525                                         MIN(
526                                                 GAMMA_RRTSTAR(
527                                                         this->nodes().size()),
528                                                 0.2)); // TODO const
529 #endif
530                         this->nodes().push_back(ns);
531                         this->add_iy(ns);
532                         // connect
533                         if (!this->connect(pn, ns, nvs)) {
                                // Unreachable: keep it out of the spatial
                                // index; it stays in nodes().
534                                 this->iy_[IYI(ns->y())].pop_back();
535                                 en_add = false;
536                         } else {
537                                 // rewire
538                                 this->rewire(nvs, ns);
539                                 pn = ns;
540                                 newly_added.push_back(pn);
541                                 if (this->goal_found(pn, this->cost)) {
                                        // Try a cheaper parent for the goal.
542                                         this->goal_cost();
543                                         this->tlog(this->findt());
544                                         en_add = false;
545                                 }
546                         }
547                 }
548         }
        // The first sample was the goal itself; skip phase 2.
549         if (this->samples().size() <= 1)
550                 return this->goal_found();
551         for (auto na: newly_added) {
552                 pn = na;
553                 en_add = true;
554                 cusps = 0;
555                 for (auto ns: this->steer(na, this->goal())) {
556                         if (!en_add) {
557                                 delete ns;
558                         } else if (IS_NEAR(pn, ns)) {
559                                 delete ns;
560                         } else {
                                // NOTE(review): unlike phase 1, a zero s()
                                // is not counted as a cusp here -- confirm
                                // which rule is intended.
561                                 if (sgn(pn->s()) != sgn(ns->s()))
562                                         cusps++;
563                                 if (cusps > 4)
564                                         en_add = false;
565                                 this->nodes().push_back(ns);
566                                 this->add_iy(ns);
567                                 pn->add_child(ns, this->cost(pn, ns));
568                                 if (this->collide(pn, ns)) {
                                        // Undo the tentative edge and the
                                        // index entry.
569                                         pn->children().pop_back();
570                                         ns->remove_parent();
571                                         this->iy_[IYI(ns->y())].pop_back();
572                                         en_add = false;
573                                 } else {
574                                         this->ocost(ns);
575                                         pn = ns;
576                                         if (this->goal_found(pn, this->cost)) {
577                                                 this->goal_cost();
578                                                 this->tlog(this->findt());
579                                                 en_add = false;
580                                         }
581                                 }
582                         }
583                 }
584         }
585         return this->goal_found();
586 }
587
588 float T2::goal_cost()
589 {
590         std::vector<RRTNode *> nvs;
591 #if NVVERSION>2
592         nvs = this->nv(
593                         this->iy_,
594                         this->goal(),
595                         this->cost,
596                         MIN(
597                                 GAMMA_RRTSTAR(
598                                         this->nodes().size()),
599                                 0.2),
600                         '0'); // TODO const
601 #elif NVVERSION>1
602         nvs = this->nv(
603                         this->iy_,
604                         this->goal(),
605                         this->cost,
606                         0.2);
607 #else
608         nvs = this->nv(
609                         this->root(),
610                         this->goal(),
611                         this->cost,
612                         0.2);
613 #endif
614         for (auto nv: nvs) {
615                 if (std::abs(this->goal()->h() - nv->h()) >=
616                                 this->GOAL_FOUND_ANGLE)
617                         continue;
618                 if (nv->ccost() + (*cost)(nv, this->goal()) >=
619                                 this->goal()->ccost())
620                         continue;
621                 RRTNode *op; // old parent
622                 float oc; // old cumulative cost
623                 float od; // old direct cost
624                 op = this->goal()->parent();
625                 oc = this->goal()->ccost();
626                 od = this->goal()->dcost();
627                 nv->add_child(this->goal(),
628                                 (*cost)(nv, this->goal()));
629                 if (this->collide(nv, this->goal())) {
630                         nv->children().pop_back();
631                         this->goal()->parent(op);
632                         this->goal()->ccost(oc);
633                         this->goal()->dcost(od);
634                 } else {
635                         op->rem_child(this->goal());
636                 }
637         }
638         return this->goal()->ccost();
639 }
640
641 T3::T3(RRTNode *init, RRTNode *goal):
642         RRTBase(init, goal),
643         p_root_(init, goal),
644         p_goal_(goal, init)
645 {
646         srand(static_cast<unsigned>(time(0)));
647 }
648
649 bool T3::next()
650 {
651         RRTNode *ron = nullptr;
652         RRTNode *gon = nullptr;
653         bool ret = false;
654         ret = this->p_root_.next();
655         ret &= this->p_goal_.next();
656         if (this->overlaptrees(&ron, &gon)) {
657                 if (this->connecttrees(ron, gon))
658                         this->goal_found(true);
659                 this->tlog(this->findt());
660                 ret &= true;
661         }
662         return ret;
663 }
664
665 bool T3::link_obstacles(
666         std::vector<CircleObstacle> *cobstacles,
667         std::vector<SegmentObstacle> *sobstacles)
668 {
669         bool ret = false;
670         ret = RRTBase::link_obstacles(cobstacles, sobstacles);
671         ret &= this->p_root_.link_obstacles(cobstacles, sobstacles);
672         ret &= this->p_goal_.link_obstacles(cobstacles, sobstacles);
673         return ret;
674 }
675
676 bool T3::connecttrees(RRTNode *rn, RRTNode *gn)
677 {
678         while (gn != this->goal()) {
679                 this->p_root_.nodes().push_back(new RRTNode(
680                                 gn->x(),
681                                 gn->y(),
682                                 gn->h()));
683                 rn->add_child(
684                                 this->p_root_.nodes().back(),
685                                 this->p_root_.cost(
686                                                 rn,
687                                                 this->p_root_.nodes().back()));
688                 rn = this->p_root_.nodes().back();
689                 gn = gn->parent();
690         }
691         rn->add_child(this->goal(), this->p_root_.cost(rn, this->goal()));
692         return true;
693 }
694
695 bool T3::overlaptrees(RRTNode **ron, RRTNode **gon)
696 {
697         for (auto rn: this->p_root_.nodes()) {
698                 if (rn->parent() == nullptr)
699                         continue;
700                 for (auto gn: this->p_goal_.nodes()) {
701                         if (gn->parent() == nullptr)
702                                 continue;
703                         if (IS_NEAR(rn, gn)) {
704                                 *ron = rn;
705                                 *gon = gn;
706                                 return true;
707                         }
708                 }
709         }
710         return false;
711 }
712
713 Klemm2015::Klemm2015(RRTNode *init, RRTNode *goal):
714         Karaman2011(init, goal),
715         orig_root_(init),
716         orig_goal_(goal)
717 {
718         srand(static_cast<unsigned>(time(0)));
719 }
720
721 bool Klemm2015::next()
722 {
723         RRTNode *xn = nullptr;
724         RRTNode *rs;
725 #if GOALFIRST > 0
726         if (this->samples().size() == 0)
727                 rs = this->goal();
728         else
729                 rs = this->sample();
730 #else
731         rs = this->sample();
732 #endif
733         this->samples().push_back(rs);
734         //std::cerr << "next" << std::endl;
735         if (this->extendstar1(rs, &xn) != 2) {
736         //        if (xn) {
737         //                std::cerr << "- xn: " << xn->x() << ", " << xn->y();
738         //                std::cerr << std::endl;
739         //        } else {
740         //                std::cerr << "- xn: nullptr" << std::endl;
741         //        }
742                 this->swap();
743                 this->connectstar(xn);
744         } else {
745                 this->swap();
746         }
747         return this->goal_found();
748 }
749
750 int Klemm2015::extendstar()
751 {
752         return 0;
753 }
754
755 int Klemm2015::connectstar()
756 {
757         return 0;
758 }
759
760 void Klemm2015::swap()
761 {
762         RRTNode *tmp;
763         tmp = this->root();
764         this->root(this->goal());
765         this->goal(tmp);
766 }