rtime.felk.cvut.cz Git - hubacji1/rrts.git/blob - src/rrts.cc
Use polygon argument for RRT* collide method
[hubacji1/rrts.git] / src / rrts.cc
#include <algorithm>
#include <cmath>

#include "rrts.h"
3
// Steering step length. Also caps the RRT* near-vertex radius in nv() and
// is the goal-reach threshold tested in next().
static constexpr double ETA = 1.0;

// Shrinking RRT* neighbourhood radius (log(n) / n)^(1/3) for a tree of n
// nodes (presumably the 1/3 exponent matches the three sampled dimensions
// x, y, h). Note GAMMA(1) == 0, so a tree holding only the root has no
// near vertices.
// This is a plain function instead of the former GNU statement-expression
// macro (`({ ... })` + `__typeof__`), a compiler extension rejected by
// non-GNU compilers and -pedantic-errors. Like the macro, it evaluates its
// argument exactly once.
static inline double GAMMA(double cV)
{
        return std::pow(std::log(cV) / cV, 1.0 / 3.0);
}
9
10 RRTNode::RRTNode()
11 {
12 }
13
14 Obstacle::Obstacle()
15 {
16 }
17
18 // RRT procedures
19 bool RRTS::collide(std::vector<std::tuple<double, double>> &poly)
20 {
21         return false;
22 }
23
24 double RRTS::cost(RRTNode &f, RRTNode &t)
25 {
26         double cost = 0;
27         cost = sqrt(pow(t.y() - f.y(), 2) + pow(t.x() - f.x(), 2));
28         return cost;
29 }
30
31 void RRTS::sample()
32 {
33         double x = this->ndx_(this->gen_);
34         double y = this->ndy_(this->gen_);
35         double h = this->ndh_(this->gen_);
36         this->samples().push_back(RRTNode());
37         this->samples().back().x(x);
38         this->samples().back().y(y);
39         this->samples().back().h(h);
40 }
41
42 RRTNode *RRTS::nn(RRTNode &t)
43 {
44         RRTNode *nn = &this->nodes().front();
45         double cost = this->cost(*nn, t);
46         for (auto &f: this->nodes()) {
47                 if (this->cost(f, t) < cost) {
48                         nn = &f;
49                         cost = this->cost(f, t);
50                 }
51         }
52         return nn;
53 }
54
55 std::vector<RRTNode *> RRTS::nv(RRTNode &t)
56 {
57         std::vector<RRTNode *> nv;
58         double cost = std::min(GAMMA(this->nodes().size()), ETA);
59         for (auto &f: this->nodes())
60                 if (this->cost(f, t) < cost)
61                         nv.push_back(&f);
62         return nv;
63 }
64
65 void RRTS::steer(RRTNode &f, RRTNode &t)
66 {
67         double angl = atan2(t.y() - f.y(), t.x() - f.x());
68         this->steered().clear();
69         this->steered().push_back(RRTNode());
70         this->steered().back().x(f.x() + ETA * cos(angl));
71         this->steered().back().y(f.y() + ETA * sin(angl));
72         this->steered().back().h(angl);
73 }
74
75 // RRT* procedures
76 bool RRTS::connect()
77 {
78         bool conn = false;
79         RRTNode *t = &this->steered().front();
80         RRTNode *f = this->nn(this->samples().back());
81         double cost = this->cost(*f, *t);
82         for (auto n: this->nv(*t)) {
83                 if (this->cost(*n, *t) < cost) {
84                         f = n;
85                         cost = this->cost(*n, *t);
86                 }
87         }
88         this->nodes().push_back(this->steered().front());
89         this->steered().erase(this->steered().begin());
90         t = &this->nodes().back();
91         t->p(f);
92         t->c(this->cost(*f, *t));
93         conn = true;
94         return conn;
95 }
96
97 void RRTS::rewire()
98 {
99         RRTNode *f = &this->nodes().back();
100         for (auto n: this->nv(*f)) {
101                 if (cc(*f) + this->cost(*f, *n) < cc(*n))
102                         n->p(f);
103         }
104 }
105
106 // API
107 std::vector<RRTNode *> RRTS::path()
108 {
109         std::vector<RRTNode *> path;
110         if (this->goals().size() == 0)
111                 return path;
112         RRTNode *goal = &this->goals().front();
113         for (auto &n: this->goals()) {
114                 if (
115                         n.p() != nullptr
116                         && (n.c() < goal->c() || goal->p() == nullptr)
117                 ) {
118                         goal = &n;
119                 }
120         }
121         if (goal->p() == nullptr)
122                 return path;
123         while (goal != nullptr) {
124                 path.push_back(goal);
125                 goal = goal->p();
126         }
127         std::reverse(path.begin(), path.end());
128         return path;
129 }
130
131 bool RRTS::next()
132 {
133         bool next = true;
134         this->icnt_++;
135         this->sample();
136         this->steer(
137                 *this->nn(this->samples().back()),
138                 this->samples().back()
139         );
140         this->connect();
141         this->rewire();
142         for (auto &n: this->goals()) {
143                 double cost = this->cost(this->nodes().back(), n);
144                 if (cost < ETA) {
145                         next = false;
146                         if (
147                                 n.p() == nullptr
148                                 || cc(this->nodes().back()) + cost < cc(n)
149                         ) {
150                                 n.p(&this->nodes().back());
151                                 n.c(cost);
152                         }
153                 }
154         }
155         if (this->icnt_ > 999)
156                 next = false;
157         return next;
158 }
159
160 void RRTS::set_sample(
161         double mx, double dx,
162         double my, double dy,
163         double mh, double dh
164 )
165 {
166         this->ndx_ = std::normal_distribution<double>(mx, dx);
167         this->ndy_ = std::normal_distribution<double>(my, dy);
168         this->ndh_ = std::normal_distribution<double>(mh, dh);
169 }
170
171 RRTS::RRTS()
172         : gen_(std::random_device{}())
173 {
174         this->goals().reserve(1);
175         this->nodes().reserve(20000);
176         this->samples().reserve(1000);
177         this->steered().reserve(20);
178         this->nodes().push_back(RRTNode()); // root
179 }
180
181 double cc(RRTNode &t)
182 {
183         RRTNode *n = &t;
184         double cost = 0;
185         while (n != nullptr) {
186                 cost += n->c();
187                 n = n->p();
188         }
189         return cost;
190 }