]> rtime.felk.cvut.cz Git - hubacji1/iamcar.git/blob - base/rrtbase.cc
Add plot nodes from goal
[hubacji1/iamcar.git] / base / rrtbase.cc
1 /*
2 This file is part of I am car.
3
4 I am car is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
8
9 I am car is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with I am car. If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #include <algorithm>
19 #include <cmath>
20 #include <omp.h>
21 #include <queue>
22 #include "bcar.h"
23 #include "rrtbase.h"
24 // OpenGL
25 #include <GL/gl.h>
26 #include <GL/glu.h>
27 #include <SDL2/SDL.h>
28 // RRT
29 #include "cost.h"
30 #include "steer.h"
31
32 extern SDL_Window* gw;
33 extern SDL_GLContext gc;
34
35 RRTBase::~RRTBase()
36 {
37         for (auto n: this->nodes_)
38                 if (n != this->root_)
39                         delete n;
40         for (auto n: this->dnodes_)
41                 if (n != this->root_ && n != this->goal_)
42                         delete n;
43         for (auto s: this->samples_)
44                 if (s != this->goal_)
45                         delete s;
46         for (auto edges: this->rlog_)
47                 for (auto e: edges)
48                         delete e;
49         delete this->root_;
50         delete this->goal_;
51 }
52
53 RRTBase::RRTBase():
54         root_(new RRTNode()),
55         goal_(new RRTNode())
56 {
57         this->nodes_.push_back(this->root_);
58         this->add_iy(this->root_);
59 }
60
61 RRTBase::RRTBase(RRTNode *init, RRTNode *goal):
62         root_(init),
63         goal_(goal)
64 {
65         this->nodes_.push_back(init);
66         this->add_iy(init);
67 }
68
69 // getter
70 RRTNode *RRTBase::root()
71 {
72         return this->root_;
73 }
74
75 RRTNode *RRTBase::goal()
76 {
77         return this->goal_;
78 }
79
80 std::vector<RRTNode *> &RRTBase::nodes()
81 {
82         return this->nodes_;
83 }
84
85 std::vector<RRTNode *> &RRTBase::dnodes()
86 {
87         return this->dnodes_;
88 }
89
90 std::vector<RRTNode *> &RRTBase::samples()
91 {
92         return this->samples_;
93 }
94
95 std::vector<CircleObstacle> *RRTBase::cos()
96 {
97         return this->cobstacles_;
98 }
99
100 std::vector<SegmentObstacle> *RRTBase::sos()
101 {
102         return this->sobstacles_;
103 }
104
105 std::vector<float> &RRTBase::clog()
106 {
107         return this->clog_;
108 }
109
110 std::vector<float> &RRTBase::nlog()
111 {
112         return this->nlog_;
113 }
114
115 std::vector<std::vector<RRTEdge *>> &RRTBase::rlog()
116 {
117         return this->rlog_;
118 }
119
120 std::vector<float> &RRTBase::slog()
121 {
122         return this->slog_;
123 }
124
125 std::vector<std::vector<RRTNode *>> &RRTBase::tlog()
126 {
127         return this->tlog_;
128 }
129
130 bool RRTBase::goal_found()
131 {
132         return this->goal_found_;
133 }
134
135 float RRTBase::elapsed()
136 {
137         std::chrono::duration<float> dt;
138         dt = std::chrono::duration_cast<std::chrono::duration<float>>(
139                         this->tend_ - this->tstart_);
140         return dt.count();
141 }
142
143 // setter
144 void RRTBase::root(RRTNode *node)
145 {
146         this->root_ = node;
147 }
148
149 void RRTBase::goal(RRTNode *node)
150 {
151         this->goal_ = node;
152 }
153
154 bool RRTBase::logr(RRTNode *root)
155 {
156         std::vector<RRTEdge *> e; // Edges to log
157         std::vector<RRTNode *> s; // DFS stack
158         std::vector<RRTNode *> r; // reset visited_
159         RRTNode *tmp;
160         s.push_back(root);
161         while (s.size() > 0) {
162                 tmp = s.back();
163                 s.pop_back();
164                 if (!tmp->visit()) {
165                         r.push_back(tmp);
166                         for (auto ch: tmp->children()) {
167                                 s.push_back(ch);
168                                 e.push_back(new RRTEdge(tmp, ch));
169                         }
170                 }
171         }
172         for (auto n: r)
173                 n->visit(false);
174         this->rlog_.push_back(e);
175         return true;
176 }
177
178 float RRTBase::ocost(RRTNode *n)
179 {
180         float dist = 9999;
181         for (auto o: *this->cobstacles_)
182                 if (o.dist_to(n) < dist)
183                         dist = o.dist_to(n);
184         for (auto o: *this->sobstacles_)
185                 if (o.dist_to(n) < dist)
186                         dist = o.dist_to(n);
187         return n->ocost(dist);
188 }
189
190 bool RRTBase::tlog(std::vector<RRTNode *> t)
191 {
192         if (t.size() > 0) {
193                 this->slog_.push_back(this->elapsed());
194                 this->clog_.push_back(t.front()->ccost() - t.back()->ccost());
195                 this->nlog_.push_back(this->nodes_.size());
196                 this->tlog_.push_back(t);
197                 return true;
198         } else {
199                 return false;
200         }
201 }
202
203 void RRTBase::tstart()
204 {
205         this->tstart_ = std::chrono::high_resolution_clock::now();
206 }
207
208 void RRTBase::tend()
209 {
210         this->tend_ = std::chrono::high_resolution_clock::now();
211 }
212
213 bool RRTBase::link_obstacles(
214                 std::vector<CircleObstacle> *cobstacles,
215                 std::vector<SegmentObstacle> *sobstacles)
216 {
217         this->cobstacles_ = cobstacles;
218         this->sobstacles_ = sobstacles;
219         if (!this->cobstacles_ || !this->sobstacles_) {
220                 return false;
221         }
222         return true;
223 }
224
225 bool RRTBase::add_iy(RRTNode *n)
226 {
227         int i = IYI(n->y());
228         if (i < 0)
229                 i = 0;
230         if (i >= IYSIZE)
231                 i = IYSIZE - 1;
232         this->iy_[i].push_back(n);
233         return true;
234 }
235
236 bool RRTBase::goal_found(bool f)
237 {
238         this->goal_found_ = f;
239         return f;
240 }
241
242 // other
243 bool RRTBase::glplot()
244 {
245         glClear(GL_COLOR_BUFFER_BIT);
246         glLineWidth(1);
247         glPointSize(1);
248         // Plot obstacles
249         glBegin(GL_LINES);
250         for (auto o: *this->sobstacles_) {
251                 glColor3f(0, 0, 0);
252                 glVertex2f(GLVERTEX(o.init()));
253                 glVertex2f(GLVERTEX(o.goal()));
254         }
255         glEnd();
256         // Plot root, goal
257         glPointSize(8);
258         glBegin(GL_POINTS);
259         glColor3f(1, 0, 0);
260         glVertex2f(GLVERTEX(this->root_));
261         glVertex2f(GLVERTEX(this->goal_));
262         glEnd();
263         // Plot last sample
264         if (this->samples_.size() > 0) {
265                 glPointSize(8);
266                 glBegin(GL_POINTS);
267                 glColor3f(0, 1, 0);
268                 glVertex2f(GLVERTEX(this->samples_.back()));
269                 glEnd();
270         }
271         // Plot nodes
272         std::vector<RRTNode *> s; // DFS stack
273         std::vector<RRTNode *> r; // reset visited_
274         RRTNode *tmp;
275         glBegin(GL_LINES);
276         s.push_back(this->root_);
277         while (s.size() > 0) {
278                 tmp = s.back();
279                 s.pop_back();
280                 if (!tmp->visit()) {
281                         r.push_back(tmp);
282                         for (auto ch: tmp->children()) {
283                                 s.push_back(ch);
284                                 glColor3f(0.5, 0.5, 0.5);
285                                 glVertex2f(GLVERTEX(tmp));
286                                 glVertex2f(GLVERTEX(ch));
287                         }
288                 }
289         }
290         glEnd();
291         // Plot nodes (from goal)
292         glBegin(GL_LINES);
293         s.push_back(this->goal_);
294         while (s.size() > 0) {
295                 tmp = s.back();
296                 s.pop_back();
297                 if (!tmp->visit()) {
298                         r.push_back(tmp);
299                         for (auto ch: tmp->children()) {
300                                 s.push_back(ch);
301                                 glColor3f(0.5, 0.5, 0.5);
302                                 glVertex2f(GLVERTEX(tmp));
303                                 glVertex2f(GLVERTEX(ch));
304                         }
305                 }
306         }
307         glEnd();
308         std::vector<RRTNode *> cusps;
309         // Plot last trajectory
310         if (this->tlog().size() > 0) {
311                 glLineWidth(2);
312                 glBegin(GL_LINES);
313                 for (auto n: this->tlog().back()) {
314                         if (n->parent()) {
315                                 glColor3f(0, 0, 1);
316                                 glVertex2f(GLVERTEX(n));
317                                 glVertex2f(GLVERTEX(n->parent()));
318                                 if (sgn(n->s()) != sgn(n->parent()->s()))
319                                         cusps.push_back(n);
320                         }
321                 }
322                 glEnd();
323         }
324         // Plot cusps
325         glPointSize(8);
326         glBegin(GL_POINTS);
327         for (auto n: cusps) {
328                 glColor3f(0, 0, 1);
329                 glVertex2f(GLVERTEX(n));
330         }
331         glEnd();
332         SDL_GL_SwapWindow(gw);
333         for (auto n: r)
334                 n->visit(false);
335         return true;
336 }
337
338 bool RRTBase::goal_found(
339                 RRTNode *node,
340                 float (*cost)(RRTNode *, RRTNode* ))
341 {
342         float xx = pow(node->x() - this->goal_->x(), 2);
343         float yy = pow(node->y() - this->goal_->y(), 2);
344         float dh = std::abs(node->h() - this->goal_->h());
345         if (IS_NEAR(node, this->goal_)) {
346                 if (this->goal_found_) {
347                         if (node->ccost() + (*cost)(node, this->goal_) <
348                                         this->goal_->ccost()) {
349                                 RRTNode *op; // old parent
350                                 float oc; // old cumulative cost
351                                 float od; // old direct cost
352                                 op = this->goal_->parent();
353                                 oc = this->goal_->ccost();
354                                 od = this->goal_->dcost();
355                                 node->add_child(this->goal_,
356                                                 (*cost)(node, this->goal_));
357                                 if (this->collide(node, this->goal_)) {
358                                         node->children().pop_back();
359                                         this->goal_->parent(op);
360                                         this->goal_->ccost(oc);
361                                         this->goal_->dcost(od);
362                                 } else {
363                                         op->rem_child(this->goal_);
364                                         return true;
365                                 }
366                         } else {
367                                 return false;
368                         }
369                 } else {
370                         node->add_child(
371                                         this->goal_,
372                                         (*cost)(node, this->goal_));
373                         if (this->collide(node, this->goal_)) {
374                                 node->children().pop_back();
375                                 this->goal_->remove_parent();
376                                 return false;
377                         }
378                         this->goal_found_ = true;
379                         return true;
380                 }
381         }
382         return false;
383 }
384
385 bool RRTBase::collide(RRTNode *init, RRTNode *goal)
386 {
387         std::vector<RRTEdge *> edges;
388         RRTNode *tmp = goal;
389         volatile bool col = false;
390         unsigned int i;
391         while (tmp != init) {
392                 BicycleCar bc(tmp->x(), tmp->y(), tmp->h());
393                 std::vector<RRTEdge *> bcframe = bc.frame();
394                 #pragma omp parallel for reduction(|: col)
395                 for (i = 0; i < (*this->cobstacles_).size(); i++) {
396                         if ((*this->cobstacles_)[i].collide(tmp)) {
397                                 col = true;
398                         }
399                         for (auto &e: bcframe) {
400                                 if ((*this->cobstacles_)[i].collide(e)) {
401                                         col = true;
402                                 }
403                         }
404                 }
405                 if (col) {
406                         for (auto e: bcframe) {
407                                 delete e->init();
408                                 delete e->goal();
409                                 delete e;
410                         }
411                         for (auto e: edges) {
412                                 delete e;
413                         }
414                         return true;
415                 }
416                 #pragma omp parallel for reduction(|: col)
417                 for (i = 0; i < (*this->sobstacles_).size(); i++) {
418                         for (auto &e: bcframe) {
419                                 if ((*this->sobstacles_)[i].collide(e)) {
420                                         col = true;
421                                 }
422                         }
423                 }
424                 if (col) {
425                         for (auto e: bcframe) {
426                                 delete e->init();
427                                 delete e->goal();
428                                 delete e;
429                         }
430                         for (auto e: edges) {
431                                 delete e;
432                         }
433                         return true;
434                 }
435                 if (!tmp->parent()) {
436                         break;
437                 }
438                 edges.push_back(new RRTEdge(tmp, tmp->parent()));
439                 tmp = tmp->parent();
440                 for (auto e: bcframe) {
441                         delete e->init();
442                         delete e->goal();
443                         delete e;
444                 }
445         }
446         for (auto &e: edges) {
447                 #pragma omp parallel for reduction(|: col)
448                 for (i = 0; i < (*this->cobstacles_).size(); i++) {
449                         if ((*this->cobstacles_)[i].collide(e)) {
450                                 col = true;
451                         }
452                 }
453                 if (col) {
454                         for (auto e: edges) {
455                                 delete e;
456                         }
457                         return true;
458                 }
459                 #pragma omp parallel for reduction(|: col)
460                 for (i = 0; i < (*this->sobstacles_).size(); i++) {
461                         if ((*this->sobstacles_)[i].collide(e)) {
462                                 col = true;
463                         }
464                 }
465                 if (col) {
466                         for (auto e: edges) {
467                                 delete e;
468                         }
469                         return true;
470                 }
471         }
472         for (auto e: edges) {
473                 delete e;
474         }
475         return false;
476 }
477
// Bookkeeping entry for one cusp in RRTBase::opt_path()'s Dijkstra search.
class RRTNodeDijkstra {
        public:
                // Entry `i` with predecessor 0 and initial cost 9999.
                RRTNodeDijkstra(int i):
                        ni(i),
                        pi(0),
                        c(9999),
                        v(false)
                {}
                // Entry `i` with predecessor 0 and initial cost `c`.
                RRTNodeDijkstra(int i, float c):
                        ni(i),
                        pi(0),
                        c(c),
                        v(false)
                {}
                // Entry `i` with predecessor `p` and initial cost `c`.
                RRTNodeDijkstra(int i, int p, float c):
                        ni(i),
                        pi(p),
                        c(c),
                        v(false)
                {}
                unsigned int ni; // index of this entry (into the cusp list)
                unsigned int pi; // index of the predecessor entry
                float c; // best cumulative cost known so far
                bool v; // visited flag
                // Mark the entry visited; returns whether it had already
                // been visited before this call.
                bool vi()
                {
                        if (this->v)
                                return true;
                        this->v = true;
                        return false;
                }
};

// Ordering for std::priority_queue that keeps the entry with the lowest
// cost `c` on top (min-heap). A const predicate returning bool, as the
// standard Compare requirements expect.
class RRTNodeDijkstraComparator {
        public:
                bool operator() (
                                const RRTNodeDijkstra& n1,
                                const RRTNodeDijkstra& n2) const
                {
                        return n1.c > n2.c;
                }
};
520
521 bool RRTBase::opt_path()
522 {
523         if (this->tlog().size() == 0)
524                 return false;
525         float oc = this->tlog().back().front()->ccost();
526         std::vector<RRTNode *> tmp_cusps;
527         for (auto n: this->tlog().back()) {
528                 if (sgn(n->s()) == 0) {
529                         tmp_cusps.push_back(n);
530                 } else if (n->parent() &&
531                                 sgn(n->s()) != sgn(n->parent()->s())) {
532                         tmp_cusps.push_back(n);
533                         tmp_cusps.push_back(n->parent());
534                 }
535         }
536         if (tmp_cusps.size() < 2)
537                 return false;
538         std::vector<RRTNode *> cusps;
539         for (unsigned int i = 0; i < tmp_cusps.size(); i++) {
540                 if (tmp_cusps[i] != tmp_cusps[i + 1])
541                         cusps.push_back(tmp_cusps[i]);
542         }
543         std::reverse(cusps.begin(), cusps.end());
544         // Begin of Dijkstra
545         std::vector<RRTNodeDijkstra> dnodes;
546         for (unsigned int i = 0; i < cusps.size(); i++)
547                 if (i > 0)
548                         dnodes.push_back(RRTNodeDijkstra(
549                                                 i,
550                                                 i - 1,
551                                                 cusps[i]->ccost()));
552                 else
553                         dnodes.push_back(RRTNodeDijkstra(
554                                                 i,
555                                                 cusps[i]->ccost()));
556         dnodes[0].vi();
557         std::priority_queue<
558                 RRTNodeDijkstra,
559                 std::vector<RRTNodeDijkstra>,
560                 RRTNodeDijkstraComparator> pq;
561         RRTNodeDijkstra tmp = dnodes[0];
562         pq.push(tmp);
563         float ch_cost = 9999;
564         std::vector<RRTNode *> steered;
565         while (!pq.empty() && tmp.ni != cusps.size() - 1) {
566                 tmp = pq.top();
567                 pq.pop();
568                 for (unsigned int i = tmp.ni + 1; i < cusps.size(); i++) {
569                         ch_cost = dnodes[tmp.ni].c +
570                                 CO(cusps[tmp.ni], cusps[i]);
571                         steered = ST(cusps[tmp.ni], cusps[i]);
572                         for (unsigned int j = 0; j < steered.size() - 1; j++)
573                                 steered[j]->add_child(
574                                                 steered[j + 1],
575                                                 CO(
576                                                         steered[j],
577                                                         steered[j + 1]));
578                         if (i != tmp.ni + 1 && this->collide( // TODO
579                                                 steered[0],
580                                                 steered[steered.size() - 1]))
581                                 continue;
582                         if (ch_cost < dnodes[i].c) {
583                                 dnodes[i].c = ch_cost;
584                                 dnodes[i].pi = tmp.ni;
585                                 if (!dnodes[i].vi())
586                                         pq.push(dnodes[i]);
587                         }
588                 }
589         }
590         if (tmp.ni != cusps.size() - 1)
591                 return false;
592         std::vector<int> npi; // new path indexes
593         int tmpi = tmp.ni;
594         while (tmpi > 0) {
595                 npi.push_back(tmpi);
596                 tmpi = dnodes[tmpi].pi;
597         }
598         npi.push_back(tmpi);
599         std::reverse(npi.begin(), npi.end());
600         RRTNode *pn = cusps[npi[0]];
601         for (unsigned int i = 0; i < npi.size() - 1; i++) {
602                 for (auto ns: ST(cusps[npi[i]], cusps[npi[i + 1]])) {
603                         pn->add_child(ns, CO(pn, ns));
604                         pn = ns;
605                 }
606         }
607         pn->add_child(
608                         this->tlog().back().front(),
609                         CO(pn, this->tlog().back().front()));
610         // End of Dijkstra
611         if (this->tlog().back().front()->ccost() < oc)
612                 return true;
613         return false;
614 }
615
616 bool RRTBase::rebase(RRTNode *nr)
617 {
618         if (!nr || this->goal_ == nr || this->root_ == nr)
619                 return false;
620         std::vector<RRTNode *> s; // DFS stack
621         RRTNode *tmp;
622         unsigned int i = 0;
623         unsigned int to_del = 0;
624         int iy = 0;
625         s.push_back(this->root_);
626         while (s.size() > 0) {
627                 tmp = s.back();
628                 s.pop_back();
629                 for (auto ch: tmp->children()) {
630                         if (ch != nr)
631                                 s.push_back(ch);
632                 }
633                 to_del = this->nodes_.size();
634                 #pragma omp parallel for reduction(min: to_del)
635                 for (i = 0; i < this->nodes_.size(); i++) {
636                         if (this->nodes_[i] == tmp)
637                                 to_del = i;
638                 }
639                 if (to_del < this->nodes_.size())
640                         this->nodes_.erase(this->nodes_.begin() + to_del);
641 #if NNVERSION > 1
642                 iy = IYI(tmp->y());
643                 to_del = this->iy_[iy].size();
644                 #pragma omp parallel  for reduction(min: to_del)
645                 for (i = 0; i < this->iy_[iy].size(); i++) {
646                         if (this->iy_[iy][i] == tmp)
647                                 to_del = i;
648                 }
649                 if (to_del < this->iy_[iy].size())
650                         this->iy_[iy].erase(this->iy_[iy].begin() + to_del);
651 #endif
652                 this->dnodes().push_back(tmp);
653         }
654         this->root_ = nr;
655         this->root_->remove_parent();
656         return true;
657 }
658
659 std::vector<RRTNode *> RRTBase::findt()
660 {
661         return this->findt(this->goal_);
662 }
663
664 std::vector<RRTNode *> RRTBase::findt(RRTNode *n)
665 {
666         std::vector<RRTNode *> nodes;
667         if (!n || !n->parent())
668                 return nodes;
669         RRTNode *tmp = n;
670         while (tmp != this->root()) {
671                 nodes.push_back(tmp);
672                 tmp = tmp->parent();
673         }
674         return nodes;
675 }