rtime.felk.cvut.cz Git - hubacji1/iamcar.git/blob - decision_control/rrtplanner.cc
Do not remove nodes from RRT
[hubacji1/iamcar.git] / decision_control / rrtplanner.cc
1 /*
2 This file is part of I am car.
3
4 I am car is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
8
9 I am car is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with I am car. If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #include <algorithm>
19 #include <cstdlib>
20 #include <ctime>
21 #include "compile.h"
22 #include "nn.h"
23 #include "nv.h"
24 #include "sample.h"
25 #include "steer.h"
26 #include "rrtplanner.h"
27 #include "cost.h"
28
/* Two-level token pasting: CAT() lets its arguments be macro-expanded
 * before CATI() glues them together, so CAT(c, CO) pastes 'c' onto
 * whatever CO expands to rather than onto the literal token CO. */
#define CATI(a, b) a##b
#define CAT(a, b) CATI(a, b)

/* Cost functions used by Kuwata2008. KUWATA2008_DCOST is the configured
 * cost CO itself; KUWATA2008_CCOST names the function obtained by
 * prefixing CO with 'c' (NOTE(review): presumably the cumulative-cost
 * variant, cf. ccost(); confirm in cost.h). */
#define KUWATA2008_DCOST CO
#define KUWATA2008_CCOST CAT(c, CO)
33
34 LaValle1998::LaValle1998(RRTNode *init, RRTNode *goal):
35         RRTBase(init, goal),
36         nn(NN),
37         sample(SA),
38         steer(ST),
39         cost(CO)
40 {
41         srand(static_cast<unsigned>(time(0)));
42 }
43
44 bool LaValle1998::next()
45 {
46         RRTNode *rs;
47 #if GOALFIRST > 0
48         if (this->samples().size() == 0)
49                 rs = this->goal();
50         else
51                 rs = this->sample();
52 #else
53         rs = this->sample();
54 #endif
55         this->samples().push_back(rs);
56 #if NNVERSION>1
57         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
58 #else
59         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
60 #endif
61         RRTNode *pn = nn;
62         bool en_add = true;
63         for (auto ns: this->steer(nn, rs)) {
64                 if (!en_add) {
65                         delete ns;
66                 } else {
67                         this->nodes().push_back(ns);
68                         this->add_iy(ns);
69                         pn->add_child(ns, this->cost(pn, ns));
70                         if (this->collide(pn, ns)) {
71                                 pn->children().pop_back();
72                                 ns->remove_parent();
73                                 this->iy_[IYI(ns->y())].pop_back();
74                                 en_add = false;
75                         } else {
76                                 this->ocost(ns);
77                                 pn = ns;
78                                 if (this->goal_found(pn, this->cost)) {
79                                         this->tlog(this->findt());
80                                         en_add = false;
81                                 }
82                         }
83                 }
84         }
85         return this->goal_found();
86 }
87
88 Kuwata2008::Kuwata2008(RRTNode *init, RRTNode *goal):
89         RRTBase(init, goal),
90         nn(NN),
91         sample(SA),
92         steer(ST),
93         cost(KUWATA2008_DCOST)
94 {
95         srand(static_cast<unsigned>(time(0)));
96 }
97
98 bool Kuwata2008::next()
99 {
100         RRTNode *rs;
101         if (this->samples().size() == 0) {
102                 rs = this->goal();
103         } else {
104                 rs = this->sample();
105         }
106         this->samples().push_back(rs);
107         float heur = static_cast<float>(rand()) / static_cast<float>(RAND_MAX);
108         if (this->goal_found()) {
109                 if (heur < 0.7)
110                         this->cost = &KUWATA2008_CCOST;
111                 else
112                         this->cost = &KUWATA2008_DCOST;
113         } else {
114                 if (heur < 0.3)
115                         this->cost = &KUWATA2008_CCOST;
116                 else
117                         this->cost = &KUWATA2008_DCOST;
118         }
119 #if NNVERSION>1
120         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
121 #else
122         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
123 #endif
124         RRTNode *pn = nn;
125         std::vector<RRTNode *> newly_added;
126         bool en_add = true;
127         for (auto ns: this->steer(nn, rs)) {
128                 if (!en_add) {
129                         delete ns;
130                 } else {
131                         this->nodes().push_back(ns);
132                         this->add_iy(ns);
133                         pn->add_child(ns, KUWATA2008_DCOST(pn, ns));
134                         if (this->collide(pn, ns)) {
135                                 pn->children().pop_back();
136                                 ns->remove_parent();
137                                 this->iy_[IYI(ns->y())].pop_back();
138                                 en_add = false;
139                         } else {
140                                 this->ocost(ns);
141                                 pn = ns;
142                                 newly_added.push_back(pn);
143                                 if (this->goal_found(pn, &KUWATA2008_DCOST)) {
144                                         this->tlog(this->findt());
145                                         en_add = false;
146                                 }
147                         }
148                 }
149         }
150         if (this->samples().size() <= 1)
151                 return this->goal_found();
152         for (auto na: newly_added) {
153                 pn = na;
154                 en_add = true;
155                 for (auto ns: this->steer(na, this->goal())) {
156                         if (!en_add) {
157                                 delete ns;
158                         } else {
159                                 this->nodes().push_back(ns);
160                                 this->add_iy(ns);
161                                 pn->add_child(ns, KUWATA2008_DCOST(pn, ns));
162                                 if (this->collide(pn, ns)) {
163                                         pn->children().pop_back();
164                                         ns->remove_parent();
165                                         this->iy_[IYI(ns->y())].pop_back();
166                                         en_add = false;
167                                 } else {
168                                         this->ocost(ns);
169                                         pn = ns;
170                                         if (this->goal_found(pn,
171                                                         &KUWATA2008_DCOST)) {
172                                                 this->tlog(this->findt());
173                                                 en_add = false;
174                                         }
175                                 }
176                         }
177                 }
178         }
179         return this->goal_found();
180 }
181
182 Karaman2011::Karaman2011(RRTNode *init, RRTNode *goal):
183         RRTBase(init, goal),
184         nn(NN),
185         nv(NV),
186         sample(SA),
187         steer(ST),
188         cost(CO)
189 {
190         srand(static_cast<unsigned>(time(0)));
191 }
192
193 bool Karaman2011::next()
194 {
195         RRTNode *rs;
196 #if GOALFIRST > 0
197         if (this->samples().size() == 0)
198                 rs = this->goal();
199         else
200                 rs = this->sample();
201 #else
202         rs = this->sample();
203 #endif
204         this->samples().push_back(rs);
205 #if NNVERSION>1
206         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
207 #else
208         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
209 #endif
210         RRTNode *pn = nn;
211         std::vector<RRTNode *> nvs;
212         bool en_add = true;
213         for (auto ns: this->steer(nn, rs)) {
214                 if (!en_add) {
215                         delete ns;
216                 } else {
217 #if NVVERSION>1
218                         nvs = this->nv(
219                                         this->iy_,
220                                         ns,
221                                         this->cost,
222                                         MIN(
223                                                 GAMMA_RRTSTAR(
224                                                         this->nodes().size()),
225                                                 0.2)); // TODO const
226 #else
227                         nvs = this->nv(
228                                         this->root(),
229                                         ns,
230                                         this->cost,
231                                         MIN(
232                                                 GAMMA_RRTSTAR(
233                                                         this->nodes().size()),
234                                                 0.2)); // TODO const
235 #endif
236                         this->nodes().push_back(ns);
237                         this->add_iy(ns);
238                         // connect
239                         if (!this->connect(pn, ns, nvs)) {
240                                 this->iy_[IYI(ns->y())].pop_back();
241                                 en_add = false;
242                         } else {
243                                 // rewire
244                                 this->rewire(nvs, ns);
245                                 pn = ns;
246                                 if (this->goal_found(pn, this->cost)) {
247                                         this->tlog(this->findt());
248                                         en_add = false;
249                                 }
250                         }
251                 }
252         }
253         return this->goal_found();
254 }
255
256 bool Karaman2011::connect(
257                 RRTNode *pn,
258                 RRTNode *ns,
259                 std::vector<RRTNode *> nvs)
260 {
261         RRTNode *op; // old parent
262         float od; // old direct cost
263         float oc; // old cumulative cost
264         bool connected = false;
265         pn->add_child(ns, this->cost(pn, ns));
266         if (this->collide(pn, ns)) {
267                 pn->children().pop_back();
268                 ns->remove_parent();
269         } else {
270                 this->ocost(ns);
271                 connected = true;
272         }
273         for (auto nv: nvs) {
274                 if (!connected || (nv->ccost() + this->cost(nv, ns) <
275                                 ns->ccost())) {
276                         op = ns->parent();
277                         od = ns->dcost();
278                         oc = ns->ccost();
279                         nv->add_child(ns, this->cost(nv, ns));
280                         if (this->collide(nv, ns)) {
281                                 nv->children().pop_back();
282                                 if (op)
283                                         ns->parent(op);
284                                 else
285                                         ns->remove_parent();
286                                 ns->dcost(od);
287                                 ns->ccost(oc);
288                         } else if (connected) {
289                                 op->children().pop_back();
290                         } else {
291                                 this->ocost(ns);
292                                 connected = true;
293                         }
294                 }
295         }
296         return connected;
297 }
298
299 bool Karaman2011::rewire(std::vector<RRTNode *> nvs, RRTNode *ns)
300 {
301         RRTNode *op; // old parent
302         float od; // old direct cost
303         float oc; // old cumulative cost
304         for (auto nv: nvs) {
305                 if (ns->ccost() + this->cost(ns, nv) < nv->ccost()) {
306                         op = nv->parent();
307                         od = nv->dcost();
308                         oc = nv->ccost();
309                         ns->add_child(nv, this->cost(ns, nv));
310                         if (this->collide(ns, nv)) {
311                                 ns->children().pop_back();
312                                 nv->parent(op);
313                                 nv->dcost(od);
314                                 nv->ccost(oc);
315                         } else {
316                                 op->rem_child(nv);
317                         }
318                 }
319         }
320         return true;
321 }
322
323 T1::T1(RRTNode *init, RRTNode *goal):
324         RRTBase(init, goal),
325         nn(NN),
326         nv(NV),
327         sample(SA),
328         steer(ST),
329         cost(CO)
330 {
331         srand(static_cast<unsigned>(time(0)));
332 }
333
334 bool T1::next()
335 {
336         RRTNode *rs;
337         if (this->samples().size() == 0)
338                 rs = this->goal();
339         else
340                 rs = this->sample();
341         this->samples().push_back(rs);
342 #if NNVERSION>1
343         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
344 #else
345         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
346 #endif
347         RRTNode *pn = nn;
348         std::vector<RRTNode *> nvs;
349         bool connected;
350         RRTNode *op; // old parent
351         float od; // old direct cost
352         float oc; // old cumulative cost
353         std::vector<RRTNode *> steered = this->steer(nn, rs);
354         // RRT* for first node
355         RRTNode *ns = steered[0];
356         {
357 #if NVVERSION>1
358                 nvs = this->nv(this->iy_, ns, this->cost, MIN(
359                                         GAMMA_RRTSTAR(this->nodes().size()),
360                                         0.2)); // TODO const
361 #else
362                 nvs = this->nv(this->root(), ns, this->cost, MIN(
363                                         GAMMA_RRTSTAR(this->nodes().size()),
364                                         0.2)); // TODO const
365 #endif
366                 this->nodes().push_back(ns);
367                 this->add_iy(ns);
368                 connected = false;
369                 pn->add_child(ns, this->cost(pn, ns));
370                 if (this->collide(pn, ns)) {
371                         pn->children().pop_back();
372                 } else {
373                         connected = true;
374                 }
375                 // connect
376                 for (auto nv: nvs) {
377                         if (!connected || (nv->ccost() + this->cost(nv, ns) <
378                                         ns->ccost())) {
379                                 op = ns->parent();
380                                 od = ns->dcost();
381                                 oc = ns->ccost();
382                                 nv->add_child(ns, this->cost(nv, ns));
383                                 if (this->collide(nv, ns)) {
384                                         nv->children().pop_back();
385                                         ns->parent(op);
386                                         ns->dcost(od);
387                                         ns->ccost(oc);
388                                 } else if (connected) {
389                                         op->children().pop_back();
390                                 } else {
391                                         connected = true;
392                                 }
393                         }
394                 }
395                 if (!connected)
396                         return false;
397                 // rewire
398                 for (auto nv: nvs) {
399                         if (ns->ccost() + this->cost(ns, nv) < nv->ccost()) {
400                                 op = nv->parent();
401                                 od = nv->dcost();
402                                 oc = nv->ccost();
403                                 ns->add_child(nv, this->cost(ns, nv));
404                                 if (this->collide(ns, nv)) {
405                                         ns->children().pop_back();
406                                         nv->parent(op);
407                                         nv->dcost(od);
408                                         nv->ccost(oc);
409                                 } else {
410                                         op->rem_child(nv);
411                                 }
412                         }
413                 }
414                 pn = ns;
415                 if (this->goal_found(pn, this->cost)) {
416                         this->tlog(this->findt());
417                 }
418         }
419         unsigned int i = 0;
420         for (i = 1; i < steered.size(); i++) {
421                 ns = steered[i];
422                 this->nodes().push_back(ns);
423                 this->add_iy(ns);
424                 pn->add_child(ns, this->cost(pn, ns));
425                 if (this->collide(pn, ns)) {
426                         pn->children().pop_back();
427                         break;
428                 }
429                 pn = ns;
430                 if (this->goal_found(pn, this->cost)) {
431                         this->tlog(this->findt());
432                         break;
433                 }
434         }
435         return this->goal_found();
436 }
437
438 bool T2::next()
439 {
440         RRTNode *rs;
441 #if GOALFIRST > 0
442         if (this->samples().size() == 0)
443                 rs = this->goal();
444         else
445                 rs = this->sample();
446 #else
447         rs = this->sample();
448 #endif
449         this->samples().push_back(rs);
450 #if NNVERSION>1
451         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
452 #else
453         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
454 #endif
455         RRTNode *pn = nn;
456         std::vector<RRTNode *> nvs;
457         std::vector<RRTNode *> newly_added;
458         bool en_add = true;
459         int cusps = 0;
460         for (auto ns: this->steer(nn, rs)) {
461                 if (!en_add) {
462                         delete ns;
463                 } else if (IS_NEAR(pn, ns)) {
464                         delete ns;
465                 } else {
466                         if (sgn(pn->s()) != sgn(ns->s()))
467                                 cusps++;
468                         if (cusps > 4)
469                                 en_add = false;
470 #if NVVERSION>1
471                         nvs = this->nv(
472                                         this->iy_,
473                                         ns,
474                                         this->cost,
475                                         MIN(
476                                                 GAMMA_RRTSTAR(
477                                                         this->nodes().size()),
478                                                 0.2)); // TODO const
479 #else
480                         nvs = this->nv(
481                                         this->root(),
482                                         ns,
483                                         this->cost,
484                                         MIN(
485                                                 GAMMA_RRTSTAR(
486                                                         this->nodes().size()),
487                                                 0.2)); // TODO const
488 #endif
489                         this->nodes().push_back(ns);
490                         this->add_iy(ns);
491                         // connect
492                         if (!this->connect(pn, ns, nvs)) {
493                                 this->iy_[IYI(ns->y())].pop_back();
494                                 en_add = false;
495                         } else {
496                                 // rewire
497                                 this->rewire(nvs, ns);
498                                 pn = ns;
499                                 newly_added.push_back(pn);
500                                 if (this->goal_found(pn, this->cost)) {
501                                         this->goal_cost();
502                                         this->tlog(this->findt());
503                                         this->opt_path();
504                                         this->tlog(this->findt());
505                                         en_add = false;
506                                 }
507                         }
508                 }
509         }
510         if (this->samples().size() <= 1)
511                 return this->goal_found();
512         for (auto na: newly_added) {
513                 pn = na;
514                 en_add = true;
515                 cusps = 0;
516                 for (auto ns: this->steer(na, this->goal())) {
517                         if (!en_add) {
518                                 delete ns;
519                         } else if (IS_NEAR(pn, ns)) {
520                                 delete ns;
521                         } else {
522                                 if (sgn(pn->s()) != sgn(ns->s()))
523                                         cusps++;
524                                 if (cusps > 4)
525                                         en_add = false;
526                                 this->nodes().push_back(ns);
527                                 this->add_iy(ns);
528                                 pn->add_child(ns, this->cost(pn, ns));
529                                 if (this->collide(pn, ns)) {
530                                         pn->children().pop_back();
531                                         ns->remove_parent();
532                                         this->iy_[IYI(ns->y())].pop_back();
533                                         en_add = false;
534                                 } else {
535                                         this->ocost(ns);
536                                         pn = ns;
537                                         if (this->goal_found(pn, this->cost)) {
538                                                 this->goal_cost();
539                                                 this->tlog(this->findt());
540                                                 this->opt_path();
541                                                 this->tlog(this->findt());
542                                                 en_add = false;
543                                         }
544                                 }
545                         }
546                 }
547         }
548         return this->goal_found();
549 }
550
551 float T2::goal_cost()
552 {
553         std::vector<RRTNode *> nvs;
554 #if NVVERSION>1
555         nvs = this->nv(
556                         this->iy_,
557                         this->goal(),
558                         this->cost,
559                         0.2);
560 #else
561         nvs = this->nv(
562                         this->root(),
563                         this->goal(),
564                         this->cost,
565                         0.2);
566 #endif
567         for (auto nv: nvs) {
568                 if (std::abs(this->goal()->h() - nv->h()) >=
569                                 this->GOAL_FOUND_ANGLE)
570                         continue;
571                 if (nv->ccost() + (*cost)(nv, this->goal()) >=
572                                 this->goal()->ccost())
573                         continue;
574                 RRTNode *op; // old parent
575                 float oc; // old cumulative cost
576                 float od; // old direct cost
577                 op = this->goal()->parent();
578                 oc = this->goal()->ccost();
579                 od = this->goal()->dcost();
580                 nv->add_child(this->goal(),
581                                 (*cost)(nv, this->goal()));
582                 if (this->collide(nv, this->goal())) {
583                         nv->children().pop_back();
584                         this->goal()->parent(op);
585                         this->goal()->ccost(oc);
586                         this->goal()->dcost(od);
587                 } else {
588                         op->rem_child(this->goal());
589                 }
590         }
591         return this->goal()->ccost();
592 }
593
594 bool T2::opt_path()
595 {
596         std::vector<RRTNode *> cusps;
597         if (this->tlog().size() == 0)
598                 return false;
599         for (auto n: this->tlog().back()) {
600                 if (n->parent() && sgn(n->s()) != sgn(n->parent()->s()))
601                                 cusps.push_back(n);
602         }
603         cusps.push_back(this->root());
604         std::reverse(cusps.begin(), cusps.end());
605         cusps.push_back(this->goal());
606         int li = cusps.size() - 1;
607         int i = li - 1;
608         while (i >= 0) {
609                 if (this->opt_part(cusps[i], cusps[li]))
610                         i--;
611                 else
612                         li = i--;
613         }
614         return true;
615 }
616
617 bool T2::opt_part(RRTNode *init, RRTNode *goal)
618 {
619         std::vector<RRTNode *> steered;
620         steered = this->steer(init, goal);
621         for (unsigned int i = 0; i < steered.size() - 1; i++) {
622                 steered[i]->add_child(
623                                 steered[i + 1],
624                                 this->cost(
625                                         steered[0],
626                                         steered[steered.size() - 1]));
627         }
628         if (this->collide(steered[0], steered[steered.size() - 1])) {
629                 for (auto n: steered)
630                         delete n;
631                 return false;
632         }
633         RRTNode *op;
634         op = init->parent();
635         if (!op)
636                 op = init;
637         op->add_child(steered[0], this->cost(op, steered[0]));
638         }
639         steered[steered.size() - 1]->add_child(
640                         goal,
641                         this->cost(
642                                 steered[steered.size() - 1],
643                                 goal));
644         return true;
645 }