]> rtime.felk.cvut.cz Git - hubacji1/rrts.git/blob - src/rrts.cc
9b4fcb930a079e6340d2166bc5d6d27938ed10a4
[hubacji1/rrts.git] / src / rrts.cc
#include <algorithm>
#include <cmath>
#include "rrts.h"

#include "reeds_shepp.h"
5
// Upper bound on the neighbourhood radius searched by nv().
static constexpr double ETA = 1.0;

// Shrinking RRT* neighbourhood radius (log(cV) / cV)^(1/3) for a tree of cV
// nodes (nv() caps it at ETA). A plain function rather than a GNU
// statement-expression macro: portable, type-checked, and the argument is
// evaluated exactly once. For cV == 1 the radius is 0.
static inline double GAMMA(double cV)
{
        return std::pow(std::log(cV) / cV, 1.0 / 3.0);
}
11
12 RRTNode::RRTNode()
13 {
14 }
15
16 RRTNode::RRTNode(const BicycleCar &bc) : BicycleCar(bc)
17 {
18 }
19
20 Obstacle::Obstacle()
21 {
22 }
23
24 double RRTS::elapsed()
25 {
26         std::chrono::duration<double> dt;
27         dt = std::chrono::duration_cast<std::chrono::duration<double>>(
28                 std::chrono::high_resolution_clock::now()
29                 - this->tstart_
30         );
31         this->scnt_ = dt.count();
32         return this->scnt_;
33 }
34
35 bool RRTS::should_stop()
36 {
37         // the following counters must be updated, do not comment
38         this->icnt_++;
39         this->elapsed();
40         // decide the stop conditions (maybe comment some lines)
41         if (this->icnt_ > 999) return true;
42         if (this->scnt_ > 10) return true;
43         if (this->gf()) return true;
44         // but continue by default
45         return false;
46 }
47
48 // RRT procedures
49 std::tuple<bool, unsigned int, unsigned int>
50 RRTS::collide(std::vector<std::tuple<double, double>> &poly)
51 {
52         for (auto &o: this->obstacles())
53                 if (std::get<0>(::collide(poly, o.poly())))
54                         return ::collide(poly, o.poly());
55         return std::make_tuple(false, 0, 0);
56 }
57
58 std::tuple<bool, unsigned int, unsigned int>
59 RRTS::collide_steered_from(RRTNode &f)
60 {
61         std::vector<std::tuple<double, double>> s;
62         s.push_back(std::make_tuple(f.x(), f.y()));
63         for (auto &n: this->steered()) {
64                 s.push_back(std::make_tuple(n.lfx(), n.lfy()));
65                 s.push_back(std::make_tuple(n.lrx(), n.lry()));
66                 s.push_back(std::make_tuple(n.rrx(), n.rry()));
67                 s.push_back(std::make_tuple(n.rfx(), n.rfy()));
68         }
69         auto col = this->collide(s);
70         auto strip_from = this->steered().size() - std::get<1>(col) / 4;
71         if (std::get<0>(col) && strip_from > 0) {
72                 while (strip_from-- > 0) {
73                         this->steered().pop_back();
74                 }
75                 return this->collide_steered_from(f);
76         }
77         return col;
78 }
79
80 std::tuple<bool, unsigned int, unsigned int>
81 RRTS::collide_two_nodes(RRTNode &f, RRTNode &t)
82 {
83         std::vector<std::tuple<double, double>> p;
84         p.push_back(std::make_tuple(f.lfx(), f.lfy()));
85         p.push_back(std::make_tuple(f.lrx(), f.lry()));
86         p.push_back(std::make_tuple(f.rrx(), f.rry()));
87         p.push_back(std::make_tuple(f.rfx(), f.rfy()));
88         p.push_back(std::make_tuple(t.lfx(), t.lfy()));
89         p.push_back(std::make_tuple(t.lrx(), t.lry()));
90         p.push_back(std::make_tuple(t.rrx(), t.rry()));
91         p.push_back(std::make_tuple(t.rfx(), t.rfy()));
92         return this->collide(p);
93 }
94
95 double RRTS::cost_build(RRTNode &f, RRTNode &t)
96 {
97         double cost = 0;
98         cost = sqrt(pow(t.y() - f.y(), 2) + pow(t.x() - f.x(), 2));
99         return cost;
100 }
101
102 double RRTS::cost_search(RRTNode &f, RRTNode &t)
103 {
104         double cost = 0;
105         cost = sqrt(pow(t.y() - f.y(), 2) + pow(t.x() - f.x(), 2));
106         return cost;
107 }
108
109 void RRTS::sample()
110 {
111         double x = this->ndx_(this->gen_);
112         double y = this->ndy_(this->gen_);
113         double h = this->ndh_(this->gen_);
114         this->samples().push_back(RRTNode());
115         this->samples().back().x(x);
116         this->samples().back().y(y);
117         this->samples().back().h(h);
118 }
119
120 RRTNode *RRTS::nn(RRTNode &t)
121 {
122         RRTNode *nn = &this->nodes().front();
123         double cost = this->cost_search(*nn, t);
124         for (auto &f: this->nodes()) {
125                 if (this->cost_search(f, t) < cost) {
126                         nn = &f;
127                         cost = this->cost_search(f, t);
128                 }
129         }
130         return nn;
131 }
132
133 std::vector<RRTNode *> RRTS::nv(RRTNode &t)
134 {
135         std::vector<RRTNode *> nv;
136         double cost = std::min(GAMMA(this->nodes().size()), ETA);
137         for (auto &f: this->nodes())
138                 if (this->cost_search(f, t) < cost)
139                         nv.push_back(&f);
140         return nv;
141 }
142
143 int cb_rs_steer(double q[3], void *user_data)
144 {
145         std::vector<RRTNode> *nodes = (std::vector<RRTNode> *) user_data;
146         nodes->push_back(RRTNode());
147         nodes->back().x(q[0]);
148         nodes->back().y(q[1]);
149         nodes->back().h(q[2]);
150         return 0;
151 }
152
153 void RRTS::steer(RRTNode &f, RRTNode &t)
154 {
155         this->steered().clear();
156         double q0[] = {f.x(), f.y(), f.h()};
157         double q1[] = {t.x(), t.y(), t.h()};
158         ReedsSheppStateSpace rsss(f.mtr());
159         rsss.sample(q0, q1, 0.5, cb_rs_steer, &this->steered());
160 }
161
162 void RRTS::join_steered(RRTNode *f)
163 {
164         while (this->steered().size() > 0) {
165                 this->nodes().push_back(this->steered().front());
166                 RRTNode *t = &this->nodes().back();
167                 t->p(f);
168                 t->c(this->cost_build(*f, *t));
169                 this->steered().erase(this->steered().begin());
170                 f = t;
171         }
172 }
173
174 bool RRTS::goal_found(RRTNode &f)
175 {
176         bool found = false;
177         for (auto &g: this->goals()) {
178                 double cost = this->cost_build(f, g);
179                 double edist = sqrt(
180                         pow(f.x() - g.x(), 2)
181                         + pow(f.y() - g.y(), 2)
182                 );
183                 double adist = std::abs(f.h() - g.h());
184                 if (edist < 0.05 && adist < M_PI / 32) {
185                         found = true;
186                         if (g.p() == nullptr || cc(f) + cost < cc(g)) {
187                                 g.p(&f);
188                                 g.c(cost);
189                         }
190                 }
191         }
192         return found;
193 }
194
195 // RRT* procedures
196 bool RRTS::connect()
197 {
198         RRTNode *t = &this->steered().front();
199         RRTNode *f = this->nn(this->samples().back());
200         double cost = this->cost_search(*f, *t);
201         for (auto n: this->nv(*t)) {
202                 if (
203                         !std::get<0>(this->collide_two_nodes(*n, *t))
204                         && this->cost_search(*n, *t) < cost
205                 ) {
206                         f = n;
207                         cost = this->cost_search(*n, *t);
208                 }
209         }
210         this->nodes().push_back(this->steered().front());
211         t = &this->nodes().back();
212         t->p(f);
213         t->c(this->cost_build(*f, *t));
214         t->set_t(RRTNodeType::connected);
215         return true;
216 }
217
218 void RRTS::rewire()
219 {
220         RRTNode *f = &this->nodes().back();
221         for (auto n: this->nv(*f)) {
222                 if (
223                         !std::get<0>(this->collide_two_nodes(*f, *n))
224                         && cc(*f) + this->cost_search(*f, *n) < cc(*n)
225                 ) {
226                         n->p(f);
227                         n->c(this->cost_build(*f, *n));
228                 }
229         }
230 }
231
232 // API
233 std::vector<RRTNode *> RRTS::path()
234 {
235         std::vector<RRTNode *> path;
236         if (this->goals().size() == 0)
237                 return path;
238         RRTNode *goal = &this->goals().front();
239         for (auto &n: this->goals()) {
240                 if (
241                         n.p() != nullptr
242                         && (n.c() < goal->c() || goal->p() == nullptr)
243                 ) {
244                         goal = &n;
245                 }
246         }
247         if (goal->p() == nullptr)
248                 return path;
249         while (goal != nullptr) {
250                 path.push_back(goal);
251                 goal = goal->p();
252         }
253         std::reverse(path.begin(), path.end());
254         return path;
255 }
256
257 bool RRTS::next()
258 {
259         if (this->icnt_ == 0)
260                 this->tstart_ = std::chrono::high_resolution_clock::now();
261         bool next = true;
262         if (this->should_stop())
263                 return false;
264         this->sample();
265         this->steer(
266                 *this->nn(this->samples().back()),
267                 this->samples().back()
268         );
269         if (std::get<0>(this->collide_steered_from(
270                 *this->nn(this->samples().back())
271         )))
272                 return next;
273         if (!this->connect())
274                 return next;
275         this->rewire();
276         unsigned scnt = this->steered().size();
277         this->steered().erase(this->steered().begin());
278         this->join_steered(&this->nodes().back());
279         RRTNode *just_added = &this->nodes().back();
280         while (scnt > 0) {
281                 scnt--;
282                 for (auto &g: this->goals()) {
283                         this->steer(*just_added, g);
284                         if (std::get<0>(this->collide_steered_from(
285                                 *just_added
286                         )))
287                                 continue;
288                         this->join_steered(just_added);
289                 }
290                 this->gf(this->goal_found(this->nodes().back()));
291                 just_added = just_added->p();
292         }
293         return next;
294 }
295
296 void RRTS::set_sample(
297         double mx, double dx,
298         double my, double dy,
299         double mh, double dh
300 )
301 {
302         this->ndx_ = std::normal_distribution<double>(mx, dx);
303         this->ndy_ = std::normal_distribution<double>(my, dy);
304         this->ndh_ = std::normal_distribution<double>(mh, dh);
305 }
306
307 RRTS::RRTS()
308         : gen_(std::random_device{}())
309 {
310         this->goals().reserve(100);
311         this->nodes().reserve(4000000);
312         this->samples().reserve(1000);
313         this->steered().reserve(20000);
314         this->nodes().push_back(RRTNode()); // root
315 }
316
317 double cc(RRTNode &t)
318 {
319         RRTNode *n = &t;
320         double cost = 0;
321         while (n != nullptr) {
322                 cost += n->c();
323                 n = n->p();
324         }
325         return cost;
326 }