rtime.felk.cvut.cz Git - hubacji1/iamcar.git/blob - decision_control/rrtplanner.cc
Fix max cusps for path from steer procedure
[hubacji1/iamcar.git] / decision_control / rrtplanner.cc
1 /*
2 This file is part of I am car.
3
4 I am car is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
8
9 I am car is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with I am car. If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #include <cstdlib>
19 #include <ctime>
20 #include "compile.h"
21 #include "nn.h"
22 #include "nv.h"
23 #include "sample.h"
24 #include "steer.h"
25 #include "rrtplanner.h"
26 #include "cost.h"
27
// Token-pasting helpers. CAT forwards through CATI so that its arguments
// (e.g. the CO macro) are macro-expanded before `##` glues them together;
// using CATI directly would paste the unexpanded spelling instead.
28 #define CATI(a, b) a ## b
29 #define CAT(a, b) CATI(a, b)
// Cost procedures used by Kuwata2008: DCOST is the configured cost CO itself,
// CCOST is the function named by CO's expansion prefixed with `c`
// (presumably a cumulative-cost variant -- confirm in cost.h).
30 #define KUWATA2008_CCOST CAT(c, CO)
31 #define KUWATA2008_DCOST CO
32
33 LaValle1998::LaValle1998(RRTNode *init, RRTNode *goal):
34         RRTBase(init, goal),
35         nn(NN),
36         sample(SA),
37         steer(ST),
38         cost(CO)
39 {
40         srand(static_cast<unsigned>(time(0)));
41 }
42
43 bool LaValle1998::next()
44 {
45         RRTNode *rs;
46 #if GOALFIRST > 0
47         if (this->samples().size() == 0)
48                 rs = this->goal();
49         else
50                 rs = this->sample();
51 #else
52         rs = this->sample();
53 #endif
54         this->samples().push_back(rs);
55 #if NNVERSION>1
56         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
57 #else
58         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
59 #endif
60         RRTNode *pn = nn;
61         bool en_add = true;
62         for (auto ns: this->steer(nn, rs)) {
63                 if (!en_add) {
64                         delete ns;
65                 } else {
66                         this->nodes().push_back(ns);
67                         this->add_iy(ns);
68                         pn->add_child(ns, this->cost(pn, ns));
69                         if (this->collide(pn, ns)) {
70                                 pn->children().pop_back();
71                                 ns->remove_parent();
72                                 this->iy_[IYI(ns->y())].pop_back();
73                                 en_add = false;
74                         } else {
75                                 this->ocost(ns);
76                                 pn = ns;
77                                 if (this->goal_found(pn, this->cost)) {
78                                         this->tlog(this->findt());
79                                         en_add = false;
80                                 }
81                         }
82                 }
83         }
84         return this->goal_found();
85 }
86
87 Kuwata2008::Kuwata2008(RRTNode *init, RRTNode *goal):
88         RRTBase(init, goal),
89         nn(NN),
90         sample(SA),
91         steer(ST),
92         cost(KUWATA2008_DCOST)
93 {
94         srand(static_cast<unsigned>(time(0)));
95 }
96
97 bool Kuwata2008::next()
98 {
99         RRTNode *rs;
100         if (this->samples().size() == 0) {
101                 rs = this->goal();
102         } else {
103                 rs = this->sample();
104         }
105         this->samples().push_back(rs);
106         float heur = static_cast<float>(rand()) / static_cast<float>(RAND_MAX);
107         if (this->goal_found()) {
108                 if (heur < 0.7)
109                         this->cost = &KUWATA2008_CCOST;
110                 else
111                         this->cost = &KUWATA2008_DCOST;
112         } else {
113                 if (heur < 0.3)
114                         this->cost = &KUWATA2008_CCOST;
115                 else
116                         this->cost = &KUWATA2008_DCOST;
117         }
118 #if NNVERSION>1
119         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
120 #else
121         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
122 #endif
123         RRTNode *pn = nn;
124         std::vector<RRTNode *> newly_added;
125         bool en_add = true;
126         for (auto ns: this->steer(nn, rs)) {
127                 if (!en_add) {
128                         delete ns;
129                 } else {
130                         this->nodes().push_back(ns);
131                         this->add_iy(ns);
132                         pn->add_child(ns, KUWATA2008_DCOST(pn, ns));
133                         if (this->collide(pn, ns)) {
134                                 pn->children().pop_back();
135                                 ns->remove_parent();
136                                 this->iy_[IYI(ns->y())].pop_back();
137                                 en_add = false;
138                         } else {
139                                 this->ocost(ns);
140                                 pn = ns;
141                                 newly_added.push_back(pn);
142                                 if (this->goal_found(pn, &KUWATA2008_DCOST)) {
143                                         this->tlog(this->findt());
144                                         en_add = false;
145                                 }
146                         }
147                 }
148         }
149         if (this->samples().size() <= 1)
150                 return this->goal_found();
151         for (auto na: newly_added) {
152                 pn = na;
153                 en_add = true;
154                 for (auto ns: this->steer(na, this->goal())) {
155                         if (!en_add) {
156                                 delete ns;
157                         } else {
158                                 this->nodes().push_back(ns);
159                                 this->add_iy(ns);
160                                 pn->add_child(ns, KUWATA2008_DCOST(pn, ns));
161                                 if (this->collide(pn, ns)) {
162                                         pn->children().pop_back();
163                                         ns->remove_parent();
164                                         this->iy_[IYI(ns->y())].pop_back();
165                                         en_add = false;
166                                 } else {
167                                         this->ocost(ns);
168                                         pn = ns;
169                                         if (this->goal_found(pn,
170                                                         &KUWATA2008_DCOST)) {
171                                                 this->tlog(this->findt());
172                                                 en_add = false;
173                                         }
174                                 }
175                         }
176                 }
177         }
178         return this->goal_found();
179 }
180
181 Karaman2011::Karaman2011(RRTNode *init, RRTNode *goal):
182         RRTBase(init, goal),
183         nn(NN),
184         nv(NV),
185         sample(SA),
186         steer(ST),
187         cost(CO)
188 {
189         srand(static_cast<unsigned>(time(0)));
190 }
191
192 bool Karaman2011::next()
193 {
194         RRTNode *rs;
195 #if GOALFIRST > 0
196         if (this->samples().size() == 0)
197                 rs = this->goal();
198         else
199                 rs = this->sample();
200 #else
201         rs = this->sample();
202 #endif
203         this->samples().push_back(rs);
204 #if NNVERSION>1
205         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
206 #else
207         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
208 #endif
209         RRTNode *pn = nn;
210         std::vector<RRTNode *> nvs;
211         bool en_add = true;
212         for (auto ns: this->steer(nn, rs)) {
213                 if (!en_add) {
214                         delete ns;
215                 } else {
216 #if NVVERSION>1
217                         nvs = this->nv(
218                                         this->iy_,
219                                         ns,
220                                         this->cost,
221                                         MIN(
222                                                 GAMMA_RRTSTAR(
223                                                         this->nodes().size()),
224                                                 0.2)); // TODO const
225 #else
226                         nvs = this->nv(
227                                         this->root(),
228                                         ns,
229                                         this->cost,
230                                         MIN(
231                                                 GAMMA_RRTSTAR(
232                                                         this->nodes().size()),
233                                                 0.2)); // TODO const
234 #endif
235                         this->nodes().push_back(ns);
236                         this->add_iy(ns);
237                         // connect
238                         if (!this->connect(pn, ns, nvs)) {
239                                 this->iy_[IYI(ns->y())].pop_back();
240                                 en_add = false;
241                         } else {
242                                 // rewire
243                                 this->rewire(nvs, ns);
244                                 pn = ns;
245                                 if (this->goal_found(pn, this->cost)) {
246                                         this->tlog(this->findt());
247                                         en_add = false;
248                                 }
249                         }
250                 }
251         }
252         return this->goal_found();
253 }
254
255 bool Karaman2011::connect(
256                 RRTNode *pn,
257                 RRTNode *ns,
258                 std::vector<RRTNode *> nvs)
259 {
260         RRTNode *op; // old parent
261         float od; // old direct cost
262         float oc; // old cumulative cost
263         bool connected = false;
264         pn->add_child(ns, this->cost(pn, ns));
265         if (this->collide(pn, ns)) {
266                 pn->children().pop_back();
267                 ns->remove_parent();
268         } else {
269                 this->ocost(ns);
270                 connected = true;
271         }
272         for (auto nv: nvs) {
273                 if (!connected || (nv->ccost() + this->cost(nv, ns) <
274                                 ns->ccost())) {
275                         op = ns->parent();
276                         od = ns->dcost();
277                         oc = ns->ccost();
278                         nv->add_child(ns, this->cost(nv, ns));
279                         if (this->collide(nv, ns)) {
280                                 nv->children().pop_back();
281                                 if (op)
282                                         ns->parent(op);
283                                 else
284                                         ns->remove_parent();
285                                 ns->dcost(od);
286                                 ns->ccost(oc);
287                         } else if (connected) {
288                                 op->children().pop_back();
289                         } else {
290                                 this->ocost(ns);
291                                 connected = true;
292                         }
293                 }
294         }
295         return connected;
296 }
297
298 bool Karaman2011::rewire(std::vector<RRTNode *> nvs, RRTNode *ns)
299 {
300         RRTNode *op; // old parent
301         float od; // old direct cost
302         float oc; // old cumulative cost
303         for (auto nv: nvs) {
304                 if (ns->ccost() + this->cost(ns, nv) < nv->ccost()) {
305                         op = nv->parent();
306                         od = nv->dcost();
307                         oc = nv->ccost();
308                         ns->add_child(nv, this->cost(ns, nv));
309                         if (this->collide(ns, nv)) {
310                                 ns->children().pop_back();
311                                 nv->parent(op);
312                                 nv->dcost(od);
313                                 nv->ccost(oc);
314                         } else {
315                                 op->rem_child(nv);
316                         }
317                 }
318         }
319         return true;
320 }
321
322 T1::T1(RRTNode *init, RRTNode *goal):
323         RRTBase(init, goal),
324         nn(NN),
325         nv(NV),
326         sample(SA),
327         steer(ST),
328         cost(CO)
329 {
330         srand(static_cast<unsigned>(time(0)));
331 }
332
333 bool T1::next()
334 {
335         RRTNode *rs;
336         if (this->samples().size() == 0)
337                 rs = this->goal();
338         else
339                 rs = this->sample();
340         this->samples().push_back(rs);
341 #if NNVERSION>1
342         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
343 #else
344         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
345 #endif
346         RRTNode *pn = nn;
347         std::vector<RRTNode *> nvs;
348         bool connected;
349         RRTNode *op; // old parent
350         float od; // old direct cost
351         float oc; // old cumulative cost
352         std::vector<RRTNode *> steered = this->steer(nn, rs);
353         // RRT* for first node
354         RRTNode *ns = steered[0];
355         {
356 #if NVVERSION>1
357                 nvs = this->nv(this->iy_, ns, this->cost, MIN(
358                                         GAMMA_RRTSTAR(this->nodes().size()),
359                                         0.2)); // TODO const
360 #else
361                 nvs = this->nv(this->root(), ns, this->cost, MIN(
362                                         GAMMA_RRTSTAR(this->nodes().size()),
363                                         0.2)); // TODO const
364 #endif
365                 this->nodes().push_back(ns);
366                 this->add_iy(ns);
367                 connected = false;
368                 pn->add_child(ns, this->cost(pn, ns));
369                 if (this->collide(pn, ns)) {
370                         pn->children().pop_back();
371                 } else {
372                         connected = true;
373                 }
374                 // connect
375                 for (auto nv: nvs) {
376                         if (!connected || (nv->ccost() + this->cost(nv, ns) <
377                                         ns->ccost())) {
378                                 op = ns->parent();
379                                 od = ns->dcost();
380                                 oc = ns->ccost();
381                                 nv->add_child(ns, this->cost(nv, ns));
382                                 if (this->collide(nv, ns)) {
383                                         nv->children().pop_back();
384                                         ns->parent(op);
385                                         ns->dcost(od);
386                                         ns->ccost(oc);
387                                 } else if (connected) {
388                                         op->children().pop_back();
389                                 } else {
390                                         connected = true;
391                                 }
392                         }
393                 }
394                 if (!connected)
395                         return false;
396                 // rewire
397                 for (auto nv: nvs) {
398                         if (ns->ccost() + this->cost(ns, nv) < nv->ccost()) {
399                                 op = nv->parent();
400                                 od = nv->dcost();
401                                 oc = nv->ccost();
402                                 ns->add_child(nv, this->cost(ns, nv));
403                                 if (this->collide(ns, nv)) {
404                                         ns->children().pop_back();
405                                         nv->parent(op);
406                                         nv->dcost(od);
407                                         nv->ccost(oc);
408                                 } else {
409                                         op->rem_child(nv);
410                                 }
411                         }
412                 }
413                 pn = ns;
414                 if (this->goal_found(pn, this->cost)) {
415                         this->tlog(this->findt());
416                 }
417         }
418         unsigned int i = 0;
419         for (i = 1; i < steered.size(); i++) {
420                 ns = steered[i];
421                 this->nodes().push_back(ns);
422                 this->add_iy(ns);
423                 pn->add_child(ns, this->cost(pn, ns));
424                 if (this->collide(pn, ns)) {
425                         pn->children().pop_back();
426                         break;
427                 }
428                 pn = ns;
429                 if (this->goal_found(pn, this->cost)) {
430                         this->tlog(this->findt());
431                         break;
432                 }
433         }
434         return this->goal_found();
435 }
436
437 bool T2::next()
438 {
439         RRTNode *rs;
440 #if GOALFIRST > 0
441         if (this->samples().size() == 0)
442                 rs = this->goal();
443         else
444                 rs = this->sample();
445 #else
446         rs = this->sample();
447 #endif
448         this->samples().push_back(rs);
449 #if NNVERSION>1
450         RRTNode *nn = this->nn(this->iy_, rs, this->cost);
451 #else
452         RRTNode *nn = this->nn(this->nodes(), rs, this->cost);
453 #endif
454         RRTNode *pn = nn;
455         std::vector<RRTNode *> nvs;
456         std::vector<RRTNode *> newly_added;
457         bool en_add = true;
458         int cusps = 0;
459         for (auto ns: this->steer(nn, rs)) {
460                 if (!en_add) {
461                         delete ns;
462                 } else if (IS_NEAR(pn, ns)) {
463                         delete ns;
464                 } else {
465                         if (sgn(pn->s()) != sgn(ns->s()))
466                                 cusps++;
467                         if (cusps > 0)
468                                 en_add = false;
469 #if NVVERSION>1
470                         nvs = this->nv(
471                                         this->iy_,
472                                         ns,
473                                         this->cost,
474                                         MIN(
475                                                 GAMMA_RRTSTAR(
476                                                         this->nodes().size()),
477                                                 0.2)); // TODO const
478 #else
479                         nvs = this->nv(
480                                         this->root(),
481                                         ns,
482                                         this->cost,
483                                         MIN(
484                                                 GAMMA_RRTSTAR(
485                                                         this->nodes().size()),
486                                                 0.2)); // TODO const
487 #endif
488                         this->nodes().push_back(ns);
489                         this->add_iy(ns);
490                         // connect
491                         if (!this->connect(pn, ns, nvs)) {
492                                 this->iy_[IYI(ns->y())].pop_back();
493                                 en_add = false;
494                         } else {
495                                 // rewire
496                                 this->rewire(nvs, ns);
497                                 pn = ns;
498                                 newly_added.push_back(pn);
499                                 if (this->goal_found(pn, this->cost)) {
500                                         this->goal_cost();
501                                         this->tlog(this->findt());
502                                         en_add = false;
503                                 }
504                         }
505                 }
506         }
507         if (this->samples().size() <= 1)
508                 return this->goal_found();
509         for (auto na: newly_added) {
510                 pn = na;
511                 en_add = true;
512                 cusps = 0;
513                 for (auto ns: this->steer(na, this->goal())) {
514                         if (!en_add) {
515                                 delete ns;
516                         } else if (IS_NEAR(pn, ns)) {
517                                 delete ns;
518                         } else {
519                                 if (sgn(pn->s()) != sgn(ns->s()))
520                                         cusps++;
521                                 if (cusps > 0)
522                                         en_add = false;
523                                 this->nodes().push_back(ns);
524                                 this->add_iy(ns);
525                                 pn->add_child(ns, this->cost(pn, ns));
526                                 if (this->collide(pn, ns)) {
527                                         pn->children().pop_back();
528                                         ns->remove_parent();
529                                         this->iy_[IYI(ns->y())].pop_back();
530                                         en_add = false;
531                                 } else {
532                                         this->ocost(ns);
533                                         pn = ns;
534                                         if (this->goal_found(pn, this->cost)) {
535                                                 this->goal_cost();
536                                                 this->tlog(this->findt());
537                                                 en_add = false;
538                                         }
539                                 }
540                         }
541                 }
542         }
543         return this->goal_found();
544 }
545
546 float T2::goal_cost()
547 {
548         std::vector<RRTNode *> nvs;
549 #if NVVERSION>1
550         nvs = this->nv(
551                         this->iy_,
552                         this->goal(),
553                         this->cost,
554                         0.2);
555 #else
556         nvs = this->nv(
557                         this->root(),
558                         this->goal(),
559                         this->cost,
560                         0.2);
561 #endif
562         for (auto nv: nvs) {
563                 if (std::abs(this->goal()->h() - nv->h()) >=
564                                 this->GOAL_FOUND_ANGLE)
565                         continue;
566                 if (nv->ccost() + (*cost)(nv, this->goal()) >=
567                                 this->goal()->ccost())
568                         continue;
569                 RRTNode *op; // old parent
570                 float oc; // old cumulative cost
571                 float od; // old direct cost
572                 op = this->goal()->parent();
573                 oc = this->goal()->ccost();
574                 od = this->goal()->dcost();
575                 nv->add_child(this->goal(),
576                                 (*cost)(nv, this->goal()));
577                 if (this->collide(nv, this->goal())) {
578                         nv->children().pop_back();
579                         this->goal()->parent(op);
580                         this->goal()->ccost(oc);
581                         this->goal()->dcost(od);
582                 } else {
583                         op->rem_child(this->goal());
584                 }
585         }
586         return this->goal()->ccost();
587 }