rtime.felk.cvut.cz Git - hubacji1/rrts.git/blob - src/rrts.cc
Merge branch 'feature/rrt-node-types'
[hubacji1/rrts.git] / src / rrts.cc
#include <algorithm>
#include <cmath>
#include "rrts.h"

#include "reeds_shepp.h"

// Upper bound on the neighbourhood radius used by nv().
static constexpr double ETA = 1.0; // for steer, nv

// RRT* shrinking-radius term (log(n) / n)^(1/3) for a tree of n nodes.
// A plain function replaces the former GNU statement-expression macro
// (`({ ... })` with `__typeof__`), which only GCC/Clang accept. Integral
// arguments (e.g. nodes().size()) convert to double exactly as the macro's
// arithmetic did.
static inline double GAMMA(double cV)
{
        return std::pow(std::log(cV) / cV, 1.0 / 3.0);
}
11
// Sign of val as an int: -1 if negative, +1 if positive, 0 otherwise
// (including zero and NaN).
template <typename T> int sgn(T val) {
        if (T(0) < val)
                return 1;
        if (val < T(0))
                return -1;
        return 0;
}
15
16 RRTNode::RRTNode()
17 {
18 }
19
20 RRTNode::RRTNode(const BicycleCar &bc) : BicycleCar(bc)
21 {
22 }
23
24 Obstacle::Obstacle()
25 {
26 }
27
28 double RRTS::elapsed()
29 {
30         std::chrono::duration<double> dt;
31         dt = std::chrono::duration_cast<std::chrono::duration<double>>(
32                 std::chrono::high_resolution_clock::now()
33                 - this->tstart_
34         );
35         this->scnt_ = dt.count();
36         return this->scnt_;
37 }
38
39 bool RRTS::should_stop()
40 {
41         // the following counters must be updated, do not comment
42         this->icnt_++;
43         this->elapsed();
44         // decide the stop conditions (maybe comment some lines)
45         if (this->icnt_ > 999) return true;
46         if (this->scnt_ > 10) return true;
47         if (this->gf()) return true;
48         // but continue by default
49         return false;
50 }
51
52 // RRT procedures
53 std::tuple<bool, unsigned int, unsigned int>
54 RRTS::collide(std::vector<std::tuple<double, double>> &poly)
55 {
56         for (auto &o: this->obstacles())
57                 if (std::get<0>(::collide(poly, o.poly())))
58                         return ::collide(poly, o.poly());
59         return std::make_tuple(false, 0, 0);
60 }
61
62 std::tuple<bool, unsigned int, unsigned int>
63 RRTS::collide_steered_from(RRTNode &f)
64 {
65         std::vector<std::tuple<double, double>> s;
66         s.push_back(std::make_tuple(f.x(), f.y()));
67         for (auto &n: this->steered()) {
68                 s.push_back(std::make_tuple(n.lfx(), n.lfy()));
69                 s.push_back(std::make_tuple(n.lrx(), n.lry()));
70                 s.push_back(std::make_tuple(n.rrx(), n.rry()));
71                 s.push_back(std::make_tuple(n.rfx(), n.rfy()));
72         }
73         auto col = this->collide(s);
74         auto strip_from = this->steered().size() - std::get<1>(col) / 4;
75         if (std::get<0>(col) && strip_from > 0) {
76                 while (strip_from-- > 0) {
77                         this->steered().pop_back();
78                 }
79                 return this->collide_steered_from(f);
80         }
81         return col;
82 }
83
84 std::tuple<bool, unsigned int, unsigned int>
85 RRTS::collide_two_nodes(RRTNode &f, RRTNode &t)
86 {
87         std::vector<std::tuple<double, double>> p;
88         p.push_back(std::make_tuple(f.lfx(), f.lfy()));
89         p.push_back(std::make_tuple(f.lrx(), f.lry()));
90         p.push_back(std::make_tuple(f.rrx(), f.rry()));
91         p.push_back(std::make_tuple(f.rfx(), f.rfy()));
92         p.push_back(std::make_tuple(t.lfx(), t.lfy()));
93         p.push_back(std::make_tuple(t.lrx(), t.lry()));
94         p.push_back(std::make_tuple(t.rrx(), t.rry()));
95         p.push_back(std::make_tuple(t.rfx(), t.rfy()));
96         return this->collide(p);
97 }
98
99 double RRTS::cost_build(RRTNode &f, RRTNode &t)
100 {
101         double cost = 0;
102         cost = sqrt(pow(t.y() - f.y(), 2) + pow(t.x() - f.x(), 2));
103         return cost;
104 }
105
106 double RRTS::cost_search(RRTNode &f, RRTNode &t)
107 {
108         double cost = 0;
109         cost = sqrt(pow(t.y() - f.y(), 2) + pow(t.x() - f.x(), 2));
110         return cost;
111 }
112
113 void RRTS::sample()
114 {
115         double x = this->ndx_(this->gen_);
116         double y = this->ndy_(this->gen_);
117         double h = this->ndh_(this->gen_);
118         this->samples().push_back(RRTNode());
119         this->samples().back().x(x);
120         this->samples().back().y(y);
121         this->samples().back().h(h);
122 }
123
124 RRTNode *RRTS::nn(RRTNode &t)
125 {
126         RRTNode *nn = &this->nodes().front();
127         double cost = this->cost_search(*nn, t);
128         for (auto &f: this->nodes()) {
129                 if (this->cost_search(f, t) < cost) {
130                         nn = &f;
131                         cost = this->cost_search(f, t);
132                 }
133         }
134         return nn;
135 }
136
137 std::vector<RRTNode *> RRTS::nv(RRTNode &t)
138 {
139         std::vector<RRTNode *> nv;
140         double cost = std::min(GAMMA(this->nodes().size()), ETA);
141         for (auto &f: this->nodes())
142                 if (this->cost_search(f, t) < cost)
143                         nv.push_back(&f);
144         return nv;
145 }
146
147 int cb_rs_steer(double q[4], void *user_data)
148 {
149         std::vector<RRTNode> *nodes = (std::vector<RRTNode> *) user_data;
150         RRTNode *ln = nullptr;
151         if (nodes->size() > 0)
152                 ln = &nodes->back();
153         nodes->push_back(RRTNode());
154         nodes->back().x(q[0]);
155         nodes->back().y(q[1]);
156         nodes->back().h(q[2]);
157         nodes->back().sp(q[3]);
158         if (nodes->back().sp() == 0)
159                 nodes->back().set_t(RRTNodeType::cusp);
160         else if (ln != nullptr && sgn(ln->sp()) != sgn(nodes->back().sp()))
161                 ln->set_t(RRTNodeType::cusp);
162         return 0;
163 }
164
165 void RRTS::steer(RRTNode &f, RRTNode &t)
166 {
167         this->steered().clear();
168         double q0[] = {f.x(), f.y(), f.h()};
169         double q1[] = {t.x(), t.y(), t.h()};
170         ReedsSheppStateSpace rsss(f.mtr());
171         rsss.sample(q0, q1, 0.5, cb_rs_steer, &this->steered());
172 }
173
174 void RRTS::join_steered(RRTNode *f)
175 {
176         while (this->steered().size() > 0) {
177                 this->nodes().push_back(this->steered().front());
178                 RRTNode *t = &this->nodes().back();
179                 t->p(f);
180                 t->c(this->cost_build(*f, *t));
181                 this->steered().erase(this->steered().begin());
182                 f = t;
183         }
184 }
185
186 bool RRTS::goal_found(RRTNode &f)
187 {
188         bool found = false;
189         for (auto &g: this->goals()) {
190                 double cost = this->cost_build(f, g);
191                 double edist = sqrt(
192                         pow(f.x() - g.x(), 2)
193                         + pow(f.y() - g.y(), 2)
194                 );
195                 double adist = std::abs(f.h() - g.h());
196                 if (edist < 0.05 && adist < M_PI / 32) {
197                         found = true;
198                         if (g.p() == nullptr || cc(f) + cost < cc(g)) {
199                                 g.p(&f);
200                                 g.c(cost);
201                         }
202                 }
203         }
204         return found;
205 }
206
207 // RRT* procedures
208 bool RRTS::connect()
209 {
210         RRTNode *t = &this->steered().front();
211         RRTNode *f = this->nn(this->samples().back());
212         double cost = this->cost_search(*f, *t);
213         for (auto n: this->nv(*t)) {
214                 if (
215                         !std::get<0>(this->collide_two_nodes(*n, *t))
216                         && this->cost_search(*n, *t) < cost
217                 ) {
218                         f = n;
219                         cost = this->cost_search(*n, *t);
220                 }
221         }
222         this->nodes().push_back(this->steered().front());
223         t = &this->nodes().back();
224         t->p(f);
225         t->c(this->cost_build(*f, *t));
226         t->set_t(RRTNodeType::connected);
227         return true;
228 }
229
230 void RRTS::rewire()
231 {
232         RRTNode *f = &this->nodes().back();
233         for (auto n: this->nv(*f)) {
234                 if (
235                         !std::get<0>(this->collide_two_nodes(*f, *n))
236                         && cc(*f) + this->cost_search(*f, *n) < cc(*n)
237                 ) {
238                         n->p(f);
239                         n->c(this->cost_build(*f, *n));
240                 }
241         }
242 }
243
244 // API
245 std::vector<RRTNode *> RRTS::path()
246 {
247         std::vector<RRTNode *> path;
248         if (this->goals().size() == 0)
249                 return path;
250         RRTNode *goal = &this->goals().front();
251         for (auto &n: this->goals()) {
252                 if (
253                         n.p() != nullptr
254                         && (n.c() < goal->c() || goal->p() == nullptr)
255                 ) {
256                         goal = &n;
257                 }
258         }
259         if (goal->p() == nullptr)
260                 return path;
261         while (goal != nullptr) {
262                 path.push_back(goal);
263                 goal = goal->p();
264         }
265         std::reverse(path.begin(), path.end());
266         return path;
267 }
268
269 bool RRTS::next()
270 {
271         if (this->icnt_ == 0)
272                 this->tstart_ = std::chrono::high_resolution_clock::now();
273         bool next = true;
274         if (this->should_stop())
275                 return false;
276         this->sample();
277         this->steer(
278                 *this->nn(this->samples().back()),
279                 this->samples().back()
280         );
281         if (std::get<0>(this->collide_steered_from(
282                 *this->nn(this->samples().back())
283         )))
284                 return next;
285         if (!this->connect())
286                 return next;
287         this->rewire();
288         unsigned scnt = this->steered().size();
289         this->steered().erase(this->steered().begin());
290         this->join_steered(&this->nodes().back());
291         RRTNode *just_added = &this->nodes().back();
292         while (scnt > 0) {
293                 scnt--;
294                 for (auto &g: this->goals()) {
295                         this->steer(*just_added, g);
296                         if (std::get<0>(this->collide_steered_from(
297                                 *just_added
298                         )))
299                                 continue;
300                         this->join_steered(just_added);
301                 }
302                 this->gf(this->goal_found(this->nodes().back()));
303                 just_added = just_added->p();
304         }
305         return next;
306 }
307
308 void RRTS::set_sample(
309         double mx, double dx,
310         double my, double dy,
311         double mh, double dh
312 )
313 {
314         this->ndx_ = std::normal_distribution<double>(mx, dx);
315         this->ndy_ = std::normal_distribution<double>(my, dy);
316         this->ndh_ = std::normal_distribution<double>(mh, dh);
317 }
318
319 RRTS::RRTS()
320         : gen_(std::random_device{}())
321 {
322         this->goals().reserve(100);
323         this->nodes().reserve(4000000);
324         this->samples().reserve(1000);
325         this->steered().reserve(20000);
326         this->nodes().push_back(RRTNode()); // root
327 }
328
329 double cc(RRTNode &t)
330 {
331         RRTNode *n = &t;
332         double cost = 0;
333         while (n != nullptr) {
334                 cost += n->c();
335                 n = n->p();
336         }
337         return cost;
338 }