rtime.felk.cvut.cz Git - hubacji1/iamcar.git/blob - decision_control/rrtplanner.cc
Return false if nn() in next() fails
[hubacji1/iamcar.git] / decision_control / rrtplanner.cc
1 /*
2 This file is part of I am car.
3
4 I am car is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
8
9 I am car is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with I am car. If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #include <algorithm>
19 #include <cstdlib>
20 #include <ctime>
21 #include <queue>
22 #include "compile.h"
23 #include "nn.h"
24 #include "nv.h"
25 #include "sample.h"
26 #include "steer.h"
27 #include "rrtplanner.h"
28 #include "cost.h"
29
// Two-level token pasting: CAT() lets its arguments be macro-expanded first
// (so that CO is replaced by its value) and CATI() then glues the results.
#define CATI(lhs, rhs) lhs ## rhs
#define CAT(lhs, rhs) CATI(lhs, rhs)
// Cost functions used by Kuwata2008: the expansion of CO prefixed with "c",
// and CO itself.
#define KUWATA2008_CCOST CAT(c, CO)
#define KUWATA2008_DCOST CO
34
35 LaValle1998::LaValle1998(RRTNode *init, RRTNode *goal):
36         RRTBase(init, goal),
37         nn(NN),
38         sample(SA),
39         steer(ST),
40         cost(CO)
41 {
42         srand(static_cast<unsigned>(time(0)));
43 }
44
45 bool LaValle1998::next()
46 {
47         RRTNode *rs;
48 #if GOALFIRST > 0
49         if (this->samples().size() == 0)
50                 rs = this->goal();
51         else
52                 rs = this->sample();
53 #else
54         rs = this->sample();
55 #endif
56         this->samples().push_back(rs);
57 #if NNVERSION>1
58         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
59 #else
60         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
61 #endif
62         RRTNode *pn = nn;
63         bool en_add = true;
64         for (auto ns: this->steer(nn, rs)) {
65                 if (!en_add) {
66                         delete ns;
67                 } else {
68                         this->nodes().push_back(ns);
69                         this->add_iy(ns);
70                         pn->add_child(ns, this->cost(pn, ns));
71                         if (this->collide(pn, ns)) {
72                                 pn->children().pop_back();
73                                 ns->remove_parent();
74                                 this->iy_[IYI(ns->y())].pop_back();
75                                 en_add = false;
76                         } else {
77                                 this->ocost(ns);
78                                 pn = ns;
79                                 if (this->goal_found(pn, this->cost)) {
80                                         this->tlog(this->findt());
81                                         en_add = false;
82                                 }
83                         }
84                 }
85         }
86         return this->goal_found();
87 }
88
89 Kuwata2008::Kuwata2008(RRTNode *init, RRTNode *goal):
90         RRTBase(init, goal),
91         nn(NN),
92         sample(SA),
93         steer(ST),
94         cost(KUWATA2008_DCOST)
95 {
96         srand(static_cast<unsigned>(time(0)));
97 }
98
99 bool Kuwata2008::next()
100 {
101         RRTNode *rs;
102         if (this->samples().size() == 0) {
103                 rs = this->goal();
104         } else {
105                 rs = this->sample();
106         }
107         this->samples().push_back(rs);
108         float heur = static_cast<float>(rand()) / static_cast<float>(RAND_MAX);
109         if (this->goal_found()) {
110                 if (heur < 0.7)
111                         this->cost = &KUWATA2008_CCOST;
112                 else
113                         this->cost = &KUWATA2008_DCOST;
114         } else {
115                 if (heur < 0.3)
116                         this->cost = &KUWATA2008_CCOST;
117                 else
118                         this->cost = &KUWATA2008_DCOST;
119         }
120 #if NNVERSION>1
121         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
122 #else
123         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
124 #endif
125         RRTNode *pn = nn;
126         std::vector<RRTNode *> newly_added;
127         bool en_add = true;
128         for (auto ns: this->steer(nn, rs)) {
129                 if (!en_add) {
130                         delete ns;
131                 } else {
132                         this->nodes().push_back(ns);
133                         this->add_iy(ns);
134                         pn->add_child(ns, KUWATA2008_DCOST(pn, ns));
135                         if (this->collide(pn, ns)) {
136                                 pn->children().pop_back();
137                                 ns->remove_parent();
138                                 this->iy_[IYI(ns->y())].pop_back();
139                                 en_add = false;
140                         } else {
141                                 this->ocost(ns);
142                                 pn = ns;
143                                 newly_added.push_back(pn);
144                                 if (this->goal_found(pn, &KUWATA2008_DCOST)) {
145                                         this->tlog(this->findt());
146                                         en_add = false;
147                                 }
148                         }
149                 }
150         }
151         if (this->samples().size() <= 1)
152                 return this->goal_found();
153         for (auto na: newly_added) {
154                 pn = na;
155                 en_add = true;
156                 for (auto ns: this->steer(na, this->goal())) {
157                         if (!en_add) {
158                                 delete ns;
159                         } else {
160                                 this->nodes().push_back(ns);
161                                 this->add_iy(ns);
162                                 pn->add_child(ns, KUWATA2008_DCOST(pn, ns));
163                                 if (this->collide(pn, ns)) {
164                                         pn->children().pop_back();
165                                         ns->remove_parent();
166                                         this->iy_[IYI(ns->y())].pop_back();
167                                         en_add = false;
168                                 } else {
169                                         this->ocost(ns);
170                                         pn = ns;
171                                         if (this->goal_found(pn,
172                                                         &KUWATA2008_DCOST)) {
173                                                 this->tlog(this->findt());
174                                                 en_add = false;
175                                         }
176                                 }
177                         }
178                 }
179         }
180         return this->goal_found();
181 }
182
183 Karaman2011::Karaman2011(RRTNode *init, RRTNode *goal):
184         RRTBase(init, goal),
185         nn(NN),
186         nv(NV),
187         sample(SA),
188         steer(ST),
189         cost(CO)
190 {
191         srand(static_cast<unsigned>(time(0)));
192 }
193
194 bool Karaman2011::next()
195 {
196         RRTNode *rs;
197 #if GOALFIRST > 0
198         if (this->samples().size() == 0)
199                 rs = this->goal();
200         else
201                 rs = this->sample();
202 #else
203         rs = this->sample();
204 #endif
205         this->samples().push_back(rs);
206 #if NNVERSION>1
207         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
208 #else
209         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
210 #endif
211         RRTNode *pn = nn;
212         std::vector<RRTNode *> nvs;
213         bool en_add = true;
214         for (auto ns: this->steer(nn, rs)) {
215                 if (!en_add) {
216                         delete ns;
217                 } else {
218 #if NVVERSION>1
219                         nvs = this->nv(
220                                         this->iy_,
221                                         ns,
222                                         this->cost,
223                                         MIN(
224                                                 GAMMA_RRTSTAR(
225                                                         this->nodes().size()),
226                                                 0.2)); // TODO const
227 #else
228                         nvs = this->nv(
229                                         this->root(),
230                                         ns,
231                                         this->cost,
232                                         MIN(
233                                                 GAMMA_RRTSTAR(
234                                                         this->nodes().size()),
235                                                 0.2)); // TODO const
236 #endif
237                         this->nodes().push_back(ns);
238                         this->add_iy(ns);
239                         // connect
240                         if (!this->connect(pn, ns, nvs)) {
241                                 this->iy_[IYI(ns->y())].pop_back();
242                                 en_add = false;
243                         } else {
244                                 // rewire
245                                 this->rewire(nvs, ns);
246                                 pn = ns;
247                                 if (this->goal_found(pn, this->cost)) {
248                                         this->tlog(this->findt());
249                                         en_add = false;
250                                 }
251                         }
252                 }
253         }
254         return this->goal_found();
255 }
256
257 bool Karaman2011::connect(
258                 RRTNode *pn,
259                 RRTNode *ns,
260                 std::vector<RRTNode *> nvs)
261 {
262         RRTNode *op; // old parent
263         float od; // old direct cost
264         float oc; // old cumulative cost
265         bool connected = false;
266         pn->add_child(ns, this->cost(pn, ns));
267         if (this->collide(pn, ns)) {
268                 pn->children().pop_back();
269                 ns->remove_parent();
270         } else {
271                 this->ocost(ns);
272                 connected = true;
273         }
274         for (auto nv: nvs) {
275                 if (!connected || (nv->ccost() + this->cost(nv, ns) <
276                                 ns->ccost())) {
277                         op = ns->parent();
278                         od = ns->dcost();
279                         oc = ns->ccost();
280                         nv->add_child(ns, this->cost(nv, ns));
281                         if (this->collide(nv, ns)) {
282                                 nv->children().pop_back();
283                                 if (op)
284                                         ns->parent(op);
285                                 else
286                                         ns->remove_parent();
287                                 ns->dcost(od);
288                                 ns->ccost(oc);
289                         } else if (connected) {
290                                 op->children().pop_back();
291                         } else {
292                                 this->ocost(ns);
293                                 connected = true;
294                         }
295                 }
296         }
297         return connected;
298 }
299
300 bool Karaman2011::rewire(std::vector<RRTNode *> nvs, RRTNode *ns)
301 {
302         RRTNode *op; // old parent
303         float od; // old direct cost
304         float oc; // old cumulative cost
305         for (auto nv: nvs) {
306                 if (ns->ccost() + this->cost(ns, nv) < nv->ccost()) {
307                         op = nv->parent();
308                         od = nv->dcost();
309                         oc = nv->ccost();
310                         ns->add_child(nv, this->cost(ns, nv));
311                         if (this->collide(ns, nv)) {
312                                 ns->children().pop_back();
313                                 nv->parent(op);
314                                 nv->dcost(od);
315                                 nv->ccost(oc);
316                         } else {
317                                 op->rem_child(nv);
318                         }
319                 }
320         }
321         return true;
322 }
323
324 T1::T1(RRTNode *init, RRTNode *goal):
325         RRTBase(init, goal),
326         nn(NN),
327         nv(NV),
328         sample(SA),
329         steer(ST),
330         cost(CO)
331 {
332         srand(static_cast<unsigned>(time(0)));
333 }
334
335 bool T1::next()
336 {
337         RRTNode *rs;
338         if (this->samples().size() == 0)
339                 rs = this->goal();
340         else
341                 rs = this->sample();
342         this->samples().push_back(rs);
343 #if NNVERSION>1
344         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
345 #else
346         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
347 #endif
348         RRTNode *pn = nn;
349         std::vector<RRTNode *> nvs;
350         bool connected;
351         RRTNode *op; // old parent
352         float od; // old direct cost
353         float oc; // old cumulative cost
354         std::vector<RRTNode *> steered = this->steer(nn, rs);
355         // RRT* for first node
356         RRTNode *ns = steered[0];
357         {
358 #if NVVERSION>1
359                 nvs = this->nv(this->iy_, ns, this->cost, MIN(
360                                         GAMMA_RRTSTAR(this->nodes().size()),
361                                         0.2)); // TODO const
362 #else
363                 nvs = this->nv(this->root(), ns, this->cost, MIN(
364                                         GAMMA_RRTSTAR(this->nodes().size()),
365                                         0.2)); // TODO const
366 #endif
367                 this->nodes().push_back(ns);
368                 this->add_iy(ns);
369                 connected = false;
370                 pn->add_child(ns, this->cost(pn, ns));
371                 if (this->collide(pn, ns)) {
372                         pn->children().pop_back();
373                 } else {
374                         connected = true;
375                 }
376                 // connect
377                 for (auto nv: nvs) {
378                         if (!connected || (nv->ccost() + this->cost(nv, ns) <
379                                         ns->ccost())) {
380                                 op = ns->parent();
381                                 od = ns->dcost();
382                                 oc = ns->ccost();
383                                 nv->add_child(ns, this->cost(nv, ns));
384                                 if (this->collide(nv, ns)) {
385                                         nv->children().pop_back();
386                                         ns->parent(op);
387                                         ns->dcost(od);
388                                         ns->ccost(oc);
389                                 } else if (connected) {
390                                         op->children().pop_back();
391                                 } else {
392                                         connected = true;
393                                 }
394                         }
395                 }
396                 if (!connected)
397                         return false;
398                 // rewire
399                 for (auto nv: nvs) {
400                         if (ns->ccost() + this->cost(ns, nv) < nv->ccost()) {
401                                 op = nv->parent();
402                                 od = nv->dcost();
403                                 oc = nv->ccost();
404                                 ns->add_child(nv, this->cost(ns, nv));
405                                 if (this->collide(ns, nv)) {
406                                         ns->children().pop_back();
407                                         nv->parent(op);
408                                         nv->dcost(od);
409                                         nv->ccost(oc);
410                                 } else {
411                                         op->rem_child(nv);
412                                 }
413                         }
414                 }
415                 pn = ns;
416                 if (this->goal_found(pn, this->cost)) {
417                         this->tlog(this->findt());
418                 }
419         }
420         unsigned int i = 0;
421         for (i = 1; i < steered.size(); i++) {
422                 ns = steered[i];
423                 this->nodes().push_back(ns);
424                 this->add_iy(ns);
425                 pn->add_child(ns, this->cost(pn, ns));
426                 if (this->collide(pn, ns)) {
427                         pn->children().pop_back();
428                         break;
429                 }
430                 pn = ns;
431                 if (this->goal_found(pn, this->cost)) {
432                         this->tlog(this->findt());
433                         break;
434                 }
435         }
436         return this->goal_found();
437 }
438
439 bool T2::next()
440 {
441         RRTNode *rs;
442 #if GOALFIRST > 0
443         if (this->samples().size() == 0)
444                 rs = this->goal();
445         else
446                 rs = this->sample();
447 #else
448         rs = this->sample();
449 #endif
450         this->samples().push_back(rs);
451 #if NNVERSION>1
452         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
453 #else
454         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
455 #endif
456         if (!nn)
457                 return false;
458         RRTNode *pn = nn;
459         std::vector<RRTNode *> nvs;
460         std::vector<RRTNode *> newly_added;
461         bool en_add = true;
462         int cusps = 0;
463         for (auto ns: this->steer(nn, rs)) {
464                 if (!en_add) {
465                         delete ns;
466                 } else if (IS_NEAR(pn, ns)) {
467                         delete ns;
468                 } else {
469                         if (sgn(pn->s()) != sgn(ns->s()))
470                                 cusps++;
471                         if (cusps > 4)
472                                 en_add = false;
473 #if NVVERSION>1
474                         nvs = this->nv(
475                                         this->iy_,
476                                         ns,
477                                         this->cost,
478                                         MIN(
479                                                 GAMMA_RRTSTAR(
480                                                         this->nodes().size()),
481                                                 0.2)); // TODO const
482 #else
483                         nvs = this->nv(
484                                         this->root(),
485                                         ns,
486                                         this->cost,
487                                         MIN(
488                                                 GAMMA_RRTSTAR(
489                                                         this->nodes().size()),
490                                                 0.2)); // TODO const
491 #endif
492                         this->nodes().push_back(ns);
493                         this->add_iy(ns);
494                         // connect
495                         if (!this->connect(pn, ns, nvs)) {
496                                 this->iy_[IYI(ns->y())].pop_back();
497                                 en_add = false;
498                         } else {
499                                 // rewire
500                                 this->rewire(nvs, ns);
501                                 pn = ns;
502                                 newly_added.push_back(pn);
503                                 if (this->goal_found(pn, this->cost)) {
504                                         this->goal_cost();
505                                         this->tlog(this->findt());
506                                         en_add = false;
507                                 }
508                         }
509                 }
510         }
511         if (this->samples().size() <= 1)
512                 return this->goal_found();
513         for (auto na: newly_added) {
514                 pn = na;
515                 en_add = true;
516                 cusps = 0;
517                 for (auto ns: this->steer(na, this->goal())) {
518                         if (!en_add) {
519                                 delete ns;
520                         } else if (IS_NEAR(pn, ns)) {
521                                 delete ns;
522                         } else {
523                                 if (sgn(pn->s()) != sgn(ns->s()))
524                                         cusps++;
525                                 if (cusps > 4)
526                                         en_add = false;
527                                 this->nodes().push_back(ns);
528                                 this->add_iy(ns);
529                                 pn->add_child(ns, this->cost(pn, ns));
530                                 if (this->collide(pn, ns)) {
531                                         pn->children().pop_back();
532                                         ns->remove_parent();
533                                         this->iy_[IYI(ns->y())].pop_back();
534                                         en_add = false;
535                                 } else {
536                                         this->ocost(ns);
537                                         pn = ns;
538                                         if (this->goal_found(pn, this->cost)) {
539                                                 this->goal_cost();
540                                                 this->tlog(this->findt());
541                                                 en_add = false;
542                                         }
543                                 }
544                         }
545                 }
546         }
547         return this->goal_found();
548 }
549
550 float T2::goal_cost()
551 {
552         std::vector<RRTNode *> nvs;
553 #if NVVERSION>1
554         nvs = this->nv(
555                         this->iy_,
556                         this->goal(),
557                         this->cost,
558                         0.2);
559 #else
560         nvs = this->nv(
561                         this->root(),
562                         this->goal(),
563                         this->cost,
564                         0.2);
565 #endif
566         for (auto nv: nvs) {
567                 if (std::abs(this->goal()->h() - nv->h()) >=
568                                 this->GOAL_FOUND_ANGLE)
569                         continue;
570                 if (nv->ccost() + (*cost)(nv, this->goal()) >=
571                                 this->goal()->ccost())
572                         continue;
573                 RRTNode *op; // old parent
574                 float oc; // old cumulative cost
575                 float od; // old direct cost
576                 op = this->goal()->parent();
577                 oc = this->goal()->ccost();
578                 od = this->goal()->dcost();
579                 nv->add_child(this->goal(),
580                                 (*cost)(nv, this->goal()));
581                 if (this->collide(nv, this->goal())) {
582                         nv->children().pop_back();
583                         this->goal()->parent(op);
584                         this->goal()->ccost(oc);
585                         this->goal()->dcost(od);
586                 } else {
587                         op->rem_child(this->goal());
588                 }
589         }
590         return this->goal()->ccost();
591 }
592
// Bookkeeping record for one node of the Dijkstra search in T2::opt_path().
class RRTNodeDijkstra {
        public:
                // Node `i` with the default cost of 9999.
                RRTNodeDijkstra(int i): RRTNodeDijkstra(i, 9999) {}
                // Node `i` with cost `c`; parent index 0, not visited.
                RRTNodeDijkstra(int i, float c):
                        ni(i),
                        pi(0),
                        c(c),
                        v(false)
                {}
                unsigned int ni; // node index
                unsigned int pi; // parent index
                float c; // cost
                bool v; // visited flag
                // Mark the node visited and report whether it already was.
                bool vi()
                {
                        const bool seen = this->v;
                        this->v = true;
                        return seen;
                }
};
619
620 class RRTNodeDijkstraComparator {
621         public:
622                 int operator() (
623                                 const RRTNodeDijkstra& n1,
624                                 const RRTNodeDijkstra& n2)
625                 {
626                         return n1.c > n2.c;
627                 }
628 };
629
630 bool T2::opt_path()
631 {
632         if (this->tlog().size() == 0)
633                 return false;
634         float oc = this->tlog().back().front()->ccost();
635         std::vector<RRTNode *> tmp_cusps;
636         for (auto n: this->tlog().back()) {
637                 if (n->parent() && sgn(n->s()) != sgn(n->parent()->s())) {
638                                 tmp_cusps.push_back(n);
639                                 tmp_cusps.push_back(n->parent());
640                 }
641         }
642         std::vector<RRTNode *> cusps;
643         for (unsigned int i = 0; i < tmp_cusps.size() - 1; i++) {
644                 if (tmp_cusps[i] != tmp_cusps[i + 1])
645                         cusps.push_back(tmp_cusps[i]);
646         }
647         cusps.push_back(this->root());
648         std::reverse(cusps.begin(), cusps.end());
649         cusps.push_back(this->tlog().back().front());
650         // Begin of Dijkstra
651         std::vector<RRTNodeDijkstra> dnodes;
652         for (unsigned int i = 0; i < cusps.size(); i++)
653                 dnodes.push_back(RRTNodeDijkstra(i));
654         dnodes[0].c = 0;
655         dnodes[0].vi();
656         std::priority_queue<
657                 RRTNodeDijkstra,
658                 std::vector<RRTNodeDijkstra>,
659                 RRTNodeDijkstraComparator> pq;
660         RRTNodeDijkstra tmp = dnodes[0];
661         pq.push(tmp);
662         float ch_cost = 9999;
663         std::vector<RRTNode *> steered;
664         while (!pq.empty() && tmp.ni != cusps.size() - 1) {
665                 tmp = pq.top();
666                 pq.pop();
667                 for (unsigned int i = tmp.ni + 1; i < cusps.size(); i++) {
668                         ch_cost = dnodes[tmp.ni].c +
669                                 this->cost(cusps[tmp.ni], cusps[i]);
670                         steered = this->steer(cusps[tmp.ni], cusps[i]);
671                         for (unsigned int j = 0; j < steered.size() - 1; j++)
672                                 steered[j]->add_child(
673                                                 steered[j + 1],
674                                                 this->cost(
675                                                         steered[j],
676                                                         steered[j + 1]));
677                         if (this->collide(
678                                                 steered[0],
679                                                 steered[steered.size() - 1]))
680                                 ch_cost = 9999 + 1;
681                         if (ch_cost < dnodes[i].c) {
682                                 dnodes[i].c = ch_cost;
683                                 dnodes[i].pi = tmp.ni;
684                                 if (!dnodes[i].vi())
685                                         pq.push(dnodes[i]);
686                         }
687                 }
688         }
689         std::vector<int> npi; // new path indexes
690         int tmpi = tmp.ni;
691         while (tmpi > 0) {
692                 npi.push_back(tmpi);
693                 tmpi = dnodes[tmpi].pi;
694         }
695         npi.push_back(tmpi);
696         std::vector<RRTNode *> npn; // new path nodes
697         std::reverse(npi.begin(), npi.end());
698         RRTNode *pn = cusps[0];
699         for (unsigned int i = 0; i < npi.size() - 1; i++) {
700                 for (auto ns: this->steer(cusps[npi[i]], cusps[npi[i + 1]])) {
701                         pn->add_child(ns, this->cost(pn, ns));
702                         pn = ns;
703                 }
704         }
705         pn->add_child(
706                         this->tlog().back().front(),
707                         this->cost(pn, this->tlog().back().front()));
708         // End of Dijkstra
709         if (this->tlog().back().front()->ccost() < oc)
710                 return true;
711         return false;
712 }
713
714 bool T2::opt_part(RRTNode *init, RRTNode *goal)
715 {
716         std::vector<RRTNode *> steered;
717         steered = this->steer(init, goal);
718         unsigned int ss = steered.size();
719         for (unsigned int i = 0; i < ss - 1; i++) {
720                 steered[i]->add_child(
721                                 steered[i + 1],
722                                 this->cost(
723                                         steered[0], // FIXME
724                                         steered[ss - 1])); // FIXME
725         }
726         if (this->collide(steered[0], steered[ss - 1])) {
727                 for (auto n: steered)
728                         delete n;
729                 return false;
730         }
731         RRTNode *op;
732         op = init->parent();
733         if (!op)
734                 op = init;
735         op->add_child(steered[0], this->cost(op, steered[0]));
736         steered[ss - 1]->add_child(goal, this->cost(steered[ss - 1], goal));
737         return true;
738 }