]> rtime.felk.cvut.cz Git - hubacji1/iamcar.git/blob - base/rrtbase.cc
8a05b3e0087c0570bcf4ee250184f4da5a567123
[hubacji1/iamcar.git] / base / rrtbase.cc
1 /*
2 This file is part of I am car.
3
4 I am car is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
8
9 I am car is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with I am car. If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #include <algorithm>
19 #include <cmath>
20 #include <omp.h>
21 #include <queue>
22 #include "bcar.h"
23 #include "rrtbase.h"
24 // OpenGL
25 #include <GL/gl.h>
26 #include <GL/glu.h>
27 #include <SDL2/SDL.h>
28 // RRT
29 #include "sample.h"
30 #include "cost.h"
31 #include "steer.h"
32
33 extern SDL_Window* gw;
34 extern SDL_GLContext gc;
35
36 RRTBase::~RRTBase()
37 {
38         for (auto n: this->nodes_)
39                 if (n != this->root_)
40                         delete n;
41         for (auto n: this->dnodes_)
42                 if (n != this->root_ && n != this->goal_)
43                         delete n;
44         for (auto s: this->samples_)
45                 if (s != this->goal_)
46                         delete s;
47         for (auto edges: this->rlog_)
48                 for (auto e: edges)
49                         delete e;
50         delete this->root_;
51         delete this->goal_;
52 }
53
54 RRTBase::RRTBase():
55         root_(new RRTNode()),
56         goal_(new RRTNode())
57 {
58         this->nodes_.push_back(this->root_);
59         this->add_iy(this->root_);
60 }
61
62 RRTBase::RRTBase(RRTNode *init, RRTNode *goal):
63         root_(init),
64         goal_(goal)
65 {
66         this->nodes_.push_back(init);
67         this->add_iy(init);
68 }
69
70 // getter
71 RRTNode *RRTBase::root()
72 {
73         return this->root_;
74 }
75
76 RRTNode *RRTBase::goal()
77 {
78         return this->goal_;
79 }
80
81 std::vector<RRTNode *> &RRTBase::nodes()
82 {
83         return this->nodes_;
84 }
85
86 std::vector<RRTNode *> &RRTBase::dnodes()
87 {
88         return this->dnodes_;
89 }
90
91 std::vector<RRTNode *> &RRTBase::samples()
92 {
93         return this->samples_;
94 }
95
96 std::vector<CircleObstacle> *RRTBase::cos()
97 {
98         return this->cobstacles_;
99 }
100
101 std::vector<SegmentObstacle> *RRTBase::sos()
102 {
103         return this->sobstacles_;
104 }
105
106 std::vector<float> &RRTBase::clog()
107 {
108         return this->clog_;
109 }
110
111 std::vector<float> &RRTBase::nlog()
112 {
113         return this->nlog_;
114 }
115
116 std::vector<std::vector<RRTEdge *>> &RRTBase::rlog()
117 {
118         return this->rlog_;
119 }
120
121 std::vector<float> &RRTBase::slog()
122 {
123         return this->slog_;
124 }
125
126 std::vector<std::vector<RRTNode *>> &RRTBase::tlog()
127 {
128         return this->tlog_;
129 }
130
131 bool RRTBase::goal_found()
132 {
133         return this->goal_found_;
134 }
135
136 float RRTBase::elapsed()
137 {
138         std::chrono::duration<float> dt;
139         dt = std::chrono::duration_cast<std::chrono::duration<float>>(
140                         this->tend_ - this->tstart_);
141         return dt.count();
142 }
143
144 // setter
145 void RRTBase::root(RRTNode *node)
146 {
147         this->root_ = node;
148 }
149
150 void RRTBase::goal(RRTNode *node)
151 {
152         this->goal_ = node;
153 }
154
155 bool RRTBase::logr(RRTNode *root)
156 {
157         std::vector<RRTEdge *> e; // Edges to log
158         std::vector<RRTNode *> s; // DFS stack
159         std::vector<RRTNode *> r; // reset visited_
160         RRTNode *tmp;
161         s.push_back(root);
162         while (s.size() > 0) {
163                 tmp = s.back();
164                 s.pop_back();
165                 if (!tmp->visit()) {
166                         r.push_back(tmp);
167                         for (auto ch: tmp->children()) {
168                                 s.push_back(ch);
169                                 e.push_back(new RRTEdge(tmp, ch));
170                         }
171                 }
172         }
173         for (auto n: r)
174                 n->visit(false);
175         this->rlog_.push_back(e);
176         return true;
177 }
178
179 float RRTBase::ocost(RRTNode *n)
180 {
181         float dist = 9999;
182         for (auto o: *this->cobstacles_)
183                 if (o.dist_to(n) < dist)
184                         dist = o.dist_to(n);
185         for (auto o: *this->sobstacles_)
186                 if (o.dist_to(n) < dist)
187                         dist = o.dist_to(n);
188         return n->ocost(dist);
189 }
190
191 bool RRTBase::tlog(std::vector<RRTNode *> t)
192 {
193         if (t.size() > 0) {
194                 this->slog_.push_back(this->elapsed());
195                 this->clog_.push_back(t.front()->ccost() - t.back()->ccost());
196                 this->nlog_.push_back(this->nodes_.size());
197                 this->tlog_.push_back(t);
198                 return true;
199         } else {
200                 return false;
201         }
202 }
203
204 void RRTBase::tstart()
205 {
206         this->tstart_ = std::chrono::high_resolution_clock::now();
207 }
208
209 void RRTBase::tend()
210 {
211         this->tend_ = std::chrono::high_resolution_clock::now();
212 }
213
214 bool RRTBase::link_obstacles(
215                 std::vector<CircleObstacle> *cobstacles,
216                 std::vector<SegmentObstacle> *sobstacles)
217 {
218         this->cobstacles_ = cobstacles;
219         this->sobstacles_ = sobstacles;
220         if (!this->cobstacles_ || !this->sobstacles_) {
221                 return false;
222         }
223         return true;
224 }
225
226 bool RRTBase::add_iy(RRTNode *n)
227 {
228         int i = IYI(n->y());
229         if (i < 0)
230                 i = 0;
231         if (i >= IYSIZE)
232                 i = IYSIZE - 1;
233         this->iy_[i].push_back(n);
234         return true;
235 }
236
237 bool RRTBase::goal_found(bool f)
238 {
239         this->goal_found_ = f;
240         return f;
241 }
242
243 // other
244 bool RRTBase::glplot()
245 {
246         glClear(GL_COLOR_BUFFER_BIT);
247         glLineWidth(1);
248         glPointSize(1);
249         // Plot obstacles
250         glBegin(GL_LINES);
251         for (auto o: *this->sobstacles_) {
252                 glColor3f(0, 0, 0);
253                 glVertex2f(GLVERTEX(o.init()));
254                 glVertex2f(GLVERTEX(o.goal()));
255         }
256         glEnd();
257         // Plot root, goal
258         glPointSize(8);
259         glBegin(GL_POINTS);
260         glColor3f(1, 0, 0);
261         glVertex2f(GLVERTEX(this->root_));
262         glVertex2f(GLVERTEX(this->goal_));
263         glEnd();
264         // Plot last sample
265         if (this->samples_.size() > 0) {
266                 glPointSize(8);
267                 glBegin(GL_POINTS);
268                 glColor3f(0, 1, 0);
269                 glVertex2f(GLVERTEX(this->samples_.back()));
270                 glEnd();
271         }
272         // Plot nodes
273         std::vector<RRTNode *> s; // DFS stack
274         std::vector<RRTNode *> r; // reset visited_
275         RRTNode *tmp;
276         glBegin(GL_LINES);
277         s.push_back(this->root_);
278         while (s.size() > 0) {
279                 tmp = s.back();
280                 s.pop_back();
281                 if (!tmp->visit()) {
282                         r.push_back(tmp);
283                         for (auto ch: tmp->children()) {
284                                 s.push_back(ch);
285                                 glColor3f(0.5, 0.5, 0.5);
286                                 glVertex2f(GLVERTEX(tmp));
287                                 glVertex2f(GLVERTEX(ch));
288                         }
289                 }
290         }
291         glEnd();
292         // Plot nodes (from goal)
293         glBegin(GL_LINES);
294         s.push_back(this->goal_);
295         while (s.size() > 0) {
296                 tmp = s.back();
297                 s.pop_back();
298                 if (!tmp->visit()) {
299                         r.push_back(tmp);
300                         for (auto ch: tmp->children()) {
301                                 s.push_back(ch);
302                                 glColor3f(0.5, 0.5, 0.5);
303                                 glVertex2f(GLVERTEX(tmp));
304                                 glVertex2f(GLVERTEX(ch));
305                         }
306                 }
307         }
308         glEnd();
309         std::vector<RRTNode *> cusps;
310         // Plot last trajectory
311         if (this->tlog().size() > 0) {
312                 glLineWidth(2);
313                 glBegin(GL_LINES);
314                 for (auto n: this->tlog().back()) {
315                         if (n->parent()) {
316                                 glColor3f(0, 0, 1);
317                                 glVertex2f(GLVERTEX(n));
318                                 glVertex2f(GLVERTEX(n->parent()));
319                                 if (sgn(n->s()) != sgn(n->parent()->s()))
320                                         cusps.push_back(n);
321                         }
322                 }
323                 glEnd();
324         }
325         // Plot cusps
326         glPointSize(8);
327         glBegin(GL_POINTS);
328         for (auto n: cusps) {
329                 glColor3f(0, 0, 1);
330                 glVertex2f(GLVERTEX(n));
331         }
332         glEnd();
333         SDL_GL_SwapWindow(gw);
334         for (auto n: r)
335                 n->visit(false);
336         return true;
337 }
338
339 bool RRTBase::goal_found(
340                 RRTNode *node,
341                 float (*cost)(RRTNode *, RRTNode* ))
342 {
343         float xx = pow(node->x() - this->goal_->x(), 2);
344         float yy = pow(node->y() - this->goal_->y(), 2);
345         float dh = std::abs(node->h() - this->goal_->h());
346         if (IS_NEAR(node, this->goal_)) {
347                 if (this->goal_found_) {
348                         if (node->ccost() + (*cost)(node, this->goal_) <
349                                         this->goal_->ccost()) {
350                                 RRTNode *op; // old parent
351                                 float oc; // old cumulative cost
352                                 float od; // old direct cost
353                                 op = this->goal_->parent();
354                                 oc = this->goal_->ccost();
355                                 od = this->goal_->dcost();
356                                 node->add_child(this->goal_,
357                                                 (*cost)(node, this->goal_));
358                                 if (this->collide(node, this->goal_)) {
359                                         node->children().pop_back();
360                                         this->goal_->parent(op);
361                                         this->goal_->ccost(oc);
362                                         this->goal_->dcost(od);
363                                 } else {
364                                         op->rem_child(this->goal_);
365                                         return true;
366                                 }
367                         } else {
368                                 return false;
369                         }
370                 } else {
371                         node->add_child(
372                                         this->goal_,
373                                         (*cost)(node, this->goal_));
374                         if (this->collide(node, this->goal_)) {
375                                 node->children().pop_back();
376                                 this->goal_->remove_parent();
377                                 return false;
378                         }
379                         this->goal_found_ = true;
380                         return true;
381                 }
382         }
383         return false;
384 }
385
386 bool RRTBase::collide(RRTNode *init, RRTNode *goal)
387 {
388         std::vector<RRTEdge *> edges;
389         RRTNode *tmp = goal;
390         volatile bool col = false;
391         unsigned int i;
392         while (tmp != init) {
393                 BicycleCar bc(tmp->x(), tmp->y(), tmp->h());
394                 std::vector<RRTEdge *> bcframe = bc.frame();
395                 #pragma omp parallel for reduction(|: col)
396                 for (i = 0; i < (*this->cobstacles_).size(); i++) {
397                         if ((*this->cobstacles_)[i].collide(tmp)) {
398                                 col = true;
399                         }
400                         for (auto &e: bcframe) {
401                                 if ((*this->cobstacles_)[i].collide(e)) {
402                                         col = true;
403                                 }
404                         }
405                 }
406                 if (col) {
407                         for (auto e: bcframe) {
408                                 delete e->init();
409                                 delete e->goal();
410                                 delete e;
411                         }
412                         for (auto e: edges) {
413                                 delete e;
414                         }
415                         return true;
416                 }
417                 #pragma omp parallel for reduction(|: col)
418                 for (i = 0; i < (*this->sobstacles_).size(); i++) {
419                         for (auto &e: bcframe) {
420                                 if ((*this->sobstacles_)[i].collide(e)) {
421                                         col = true;
422                                 }
423                         }
424                 }
425                 if (col) {
426                         for (auto e: bcframe) {
427                                 delete e->init();
428                                 delete e->goal();
429                                 delete e;
430                         }
431                         for (auto e: edges) {
432                                 delete e;
433                         }
434                         return true;
435                 }
436                 if (!tmp->parent()) {
437                         break;
438                 }
439                 edges.push_back(new RRTEdge(tmp, tmp->parent()));
440                 tmp = tmp->parent();
441                 for (auto e: bcframe) {
442                         delete e->init();
443                         delete e->goal();
444                         delete e;
445                 }
446         }
447         for (auto &e: edges) {
448                 #pragma omp parallel for reduction(|: col)
449                 for (i = 0; i < (*this->cobstacles_).size(); i++) {
450                         if ((*this->cobstacles_)[i].collide(e)) {
451                                 col = true;
452                         }
453                 }
454                 if (col) {
455                         for (auto e: edges) {
456                                 delete e;
457                         }
458                         return true;
459                 }
460                 #pragma omp parallel for reduction(|: col)
461                 for (i = 0; i < (*this->sobstacles_).size(); i++) {
462                         if ((*this->sobstacles_)[i].collide(e)) {
463                                 col = true;
464                         }
465                 }
466                 if (col) {
467                         for (auto e: edges) {
468                                 delete e;
469                         }
470                         return true;
471                 }
472         }
473         for (auto e: edges) {
474                 delete e;
475         }
476         return false;
477 }
478
/*
 * Search record for one cusp in the Dijkstra pass of opt_path().
 *   ni - index of the cusp this record describes
 *   pi - predecessor index (0 unless given)
 *   c  - cost of the record (9999 unless given)
 *   v  - visited flag, set by vi()
 */
class RRTNodeDijkstra {
        public:
                RRTNodeDijkstra(int i, int p, float c):
                        ni(i),
                        pi(p),
                        c(c),
                        v(false)
                {}
                RRTNodeDijkstra(int i, float c):
                        RRTNodeDijkstra(i, 0, c)
                {}
                RRTNodeDijkstra(int i):
                        RRTNodeDijkstra(i, 0, 9999)
                {}
                unsigned int ni;
                unsigned int pi;
                float c;
                bool v;
                /* Mark as visited; report whether it already was. */
                bool vi()
                {
                        const bool seen = this->v;
                        this->v = true;
                        return seen;
                }
};
511
512 class RRTNodeDijkstraComparator {
513         public:
514                 int operator() (
515                                 const RRTNodeDijkstra& n1,
516                                 const RRTNodeDijkstra& n2)
517                 {
518                         return n1.c > n2.c;
519                 }
520 };
521
522 bool RRTBase::opt_path()
523 {
524         if (this->tlog().size() == 0)
525                 return false;
526         float oc = this->tlog().back().front()->ccost();
527         std::vector<RRTNode *> tmp_cusps;
528         for (auto n: this->tlog().back()) {
529                 if (sgn(n->s()) == 0) {
530                         tmp_cusps.push_back(n);
531                 } else if (n->parent() &&
532                                 sgn(n->s()) != sgn(n->parent()->s())) {
533                         tmp_cusps.push_back(n);
534                         tmp_cusps.push_back(n->parent());
535                 }
536         }
537         if (tmp_cusps.size() < 2)
538                 return false;
539         std::vector<RRTNode *> cusps;
540         for (unsigned int i = 0; i < tmp_cusps.size(); i++) {
541                 if (tmp_cusps[i] != tmp_cusps[i + 1])
542                         cusps.push_back(tmp_cusps[i]);
543         }
544         std::reverse(cusps.begin(), cusps.end());
545         // Begin of Dijkstra
546         std::vector<RRTNodeDijkstra> dnodes;
547         for (unsigned int i = 0; i < cusps.size(); i++)
548                 if (i > 0)
549                         dnodes.push_back(RRTNodeDijkstra(
550                                                 i,
551                                                 i - 1,
552                                                 cusps[i]->ccost()));
553                 else
554                         dnodes.push_back(RRTNodeDijkstra(
555                                                 i,
556                                                 cusps[i]->ccost()));
557         dnodes[0].vi();
558         std::priority_queue<
559                 RRTNodeDijkstra,
560                 std::vector<RRTNodeDijkstra>,
561                 RRTNodeDijkstraComparator> pq;
562         RRTNodeDijkstra tmp = dnodes[0];
563         pq.push(tmp);
564         float ch_cost = 9999;
565         std::vector<RRTNode *> steered;
566         while (!pq.empty() && tmp.ni != cusps.size() - 1) {
567                 tmp = pq.top();
568                 pq.pop();
569                 for (unsigned int i = tmp.ni + 1; i < cusps.size(); i++) {
570                         ch_cost = dnodes[tmp.ni].c +
571                                 CO(cusps[tmp.ni], cusps[i]);
572                         steered = ST(cusps[tmp.ni], cusps[i]);
573                         for (unsigned int j = 0; j < steered.size() - 1; j++)
574                                 steered[j]->add_child(
575                                                 steered[j + 1],
576                                                 CO(
577                                                         steered[j],
578                                                         steered[j + 1]));
579                         if (i != tmp.ni + 1 && this->collide( // TODO
580                                                 steered[0],
581                                                 steered[steered.size() - 1]))
582                                 continue;
583                         if (ch_cost < dnodes[i].c) {
584                                 dnodes[i].c = ch_cost;
585                                 dnodes[i].pi = tmp.ni;
586                                 if (!dnodes[i].vi())
587                                         pq.push(dnodes[i]);
588                         }
589                 }
590         }
591         if (tmp.ni != cusps.size() - 1)
592                 return false;
593         std::vector<int> npi; // new path indexes
594         int tmpi = tmp.ni;
595         while (tmpi > 0) {
596                 npi.push_back(tmpi);
597                 tmpi = dnodes[tmpi].pi;
598         }
599         npi.push_back(tmpi);
600         std::reverse(npi.begin(), npi.end());
601         RRTNode *pn = cusps[npi[0]];
602         for (unsigned int i = 0; i < npi.size() - 1; i++) {
603                 for (auto ns: ST(cusps[npi[i]], cusps[npi[i + 1]])) {
604                         pn->add_child(ns, CO(pn, ns));
605                         pn = ns;
606                 }
607         }
608         pn->add_child(
609                         this->tlog().back().front(),
610                         CO(pn, this->tlog().back().front()));
611         // End of Dijkstra
612         if (this->tlog().back().front()->ccost() < oc)
613                 return true;
614         return false;
615 }
616
617 bool RRTBase::rebase(RRTNode *nr)
618 {
619         if (!nr || this->goal_ == nr || this->root_ == nr)
620                 return false;
621         std::vector<RRTNode *> s; // DFS stack
622         RRTNode *tmp;
623         unsigned int i = 0;
624         unsigned int to_del = 0;
625         int iy = 0;
626         s.push_back(this->root_);
627         while (s.size() > 0) {
628                 tmp = s.back();
629                 s.pop_back();
630                 for (auto ch: tmp->children()) {
631                         if (ch != nr)
632                                 s.push_back(ch);
633                 }
634                 to_del = this->nodes_.size();
635                 #pragma omp parallel for reduction(min: to_del)
636                 for (i = 0; i < this->nodes_.size(); i++) {
637                         if (this->nodes_[i] == tmp)
638                                 to_del = i;
639                 }
640                 if (to_del < this->nodes_.size())
641                         this->nodes_.erase(this->nodes_.begin() + to_del);
642 #if NNVERSION > 1
643                 iy = IYI(tmp->y());
644                 to_del = this->iy_[iy].size();
645                 #pragma omp parallel  for reduction(min: to_del)
646                 for (i = 0; i < this->iy_[iy].size(); i++) {
647                         if (this->iy_[iy][i] == tmp)
648                                 to_del = i;
649                 }
650                 if (to_del < this->iy_[iy].size())
651                         this->iy_[iy].erase(this->iy_[iy].begin() + to_del);
652 #endif
653                 this->dnodes().push_back(tmp);
654         }
655         this->root_ = nr;
656         this->root_->remove_parent();
657         return true;
658 }
659
660 std::vector<RRTNode *> RRTBase::findt()
661 {
662         return this->findt(this->goal_);
663 }
664
665 std::vector<RRTNode *> RRTBase::findt(RRTNode *n)
666 {
667         std::vector<RRTNode *> nodes;
668         if (!n || !n->parent())
669                 return nodes;
670         RRTNode *tmp = n;
671         while (tmp != this->root()) {
672                 nodes.push_back(tmp);
673                 tmp = tmp->parent();
674         }
675         return nodes;
676 }
677
678 // RRT Framework
679 RRTNode *RRTBase::sample()
680 {
681         return sa1();
682 }
683
684 float RRTBase::cost(RRTNode *init, RRTNode *goal)
685 {
686         return co2(init, goal);
687 }