rtime.felk.cvut.cz Git - hubacji1/iamcar.git/blob - decision_control/rrtplanner.cc
Add RRT testing planner `T1`
[hubacji1/iamcar.git] / decision_control / rrtplanner.cc
1 /*
2 This file is part of I am car.
3
4 I am car is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, either version 3 of the License, or
7 (at your option) any later version.
8
9 I am car is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with I am car. If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #include <cstdlib>
19 #include <ctime>
20 #include "nn.h"
21 #include "nv.h"
22 #include "sample.h"
23 #include "steer.h"
24 #include "rrtplanner.h"
25 #include "cost.h"
26
// Cost metrics used by the Kuwata2008 planner (co1/co3 presumably come from
// cost.h — confirm).
// KUWATA2008_CCOST: metric Kuwata2008::next() selects for its nearest-node
//   query with probability ~0.7 once the goal has been found, ~0.3 before.
// KUWATA2008_DCOST: the alternative metric for that query; it is also the
//   cost the Kuwata2008 constructor starts with and the metric always used
//   for edge costs and goal checks in Kuwata2008::next().
#define KUWATA2008_CCOST co3
#define KUWATA2008_DCOST co1
29
30 LaValle1998::LaValle1998(RRTNode *init, RRTNode *goal):
31         RRTBase(init, goal),
32         nn(nn1),
33         sample(sa1),
34         steer(st1),
35         cost(co1)
36 {
37         srand(static_cast<unsigned>(time(0)));
38 }
39
40 bool LaValle1998::next()
41 {
42         RRTNode *rs = this->sample();
43         this->samples().push_back(rs);
44         RRTNode *nn = this->nn(this->root(), rs, this->cost);
45         RRTNode *pn = nn;
46         for (auto ns: this->steer(nn, rs)) {
47                 this->nodes().push_back(ns);
48                 pn->add_child(ns, this->cost(pn, ns));
49                 if (this->collide(pn, ns)) {
50                         pn->children().pop_back();
51                         break;
52                 }
53                 pn = ns;
54                 if (this->goal_found(pn, this->cost)) {
55                         this->tlog(this->findt());
56                         break;
57                 }
58         }
59         return this->goal_found();
60 }
61
62 Kuwata2008::Kuwata2008(RRTNode *init, RRTNode *goal):
63         RRTBase(init, goal),
64         nn(nn1),
65         sample(sa1),
66         steer(st1),
67         cost(KUWATA2008_DCOST)
68 {
69         srand(static_cast<unsigned>(time(0)));
70 }
71
72 bool Kuwata2008::next()
73 {
74         RRTNode *rs;
75         if (this->samples().size() == 0) {
76                 rs = this->goal();
77         } else {
78                 rs = this->sample();
79         }
80         this->samples().push_back(rs);
81         float heur = static_cast<float>(rand()) / static_cast<float>(RAND_MAX);
82         if (this->goal_found()) {
83                 if (heur < 0.7)
84                         this->cost = &KUWATA2008_CCOST;
85                 else
86                         this->cost = &KUWATA2008_DCOST;
87         } else {
88                 if (heur < 0.3)
89                         this->cost = &KUWATA2008_CCOST;
90                 else
91                         this->cost = &KUWATA2008_DCOST;
92         }
93         RRTNode *nn = this->nn(this->root(), rs, this->cost);
94         RRTNode *pn = nn;
95         std::vector<RRTNode *> newly_added;
96         for (auto ns: this->steer(nn, rs)) {
97                 this->nodes().push_back(ns);
98                 pn->add_child(ns, KUWATA2008_DCOST(pn, ns));
99                 if (this->collide(pn, ns)) {
100                         pn->children().pop_back();
101                         break;
102                 }
103                 pn = ns;
104                 newly_added.push_back(pn);
105                 if (this->goal_found(pn, &KUWATA2008_DCOST)) {
106                         this->tlog(this->findt());
107                         break;
108                 }
109         }
110         if (this->samples().size() > 1) {
111                 for (auto na: newly_added) {
112                         pn = na;
113                         for (auto ns: this->steer(na, this->goal())) {
114                                 this->nodes().push_back(ns);
115                                 pn->add_child(ns, KUWATA2008_DCOST(pn, ns));
116                                 if (this->collide(pn, ns)) {
117                                         pn->children().pop_back();
118                                         break;
119                                 }
120                                 pn = ns;
121                                 if (this->goal_found(pn, &KUWATA2008_DCOST)) {
122                                         this->tlog(this->findt());
123                                         break;
124                                 }
125                         }
126                 }
127         }
128         return this->goal_found();
129 }
130
131 Karaman2011::Karaman2011(RRTNode *init, RRTNode *goal):
132         RRTBase(init, goal),
133         nn(nn1),
134         nv(nv1),
135         sample(sa1),
136         steer(st1),
137         cost(co1)
138 {
139         srand(static_cast<unsigned>(time(0)));
140 }
141
142 bool Karaman2011::next()
143 {
144         RRTNode *rs = this->sample();
145         this->samples().push_back(rs);
146         RRTNode *nn = this->nn(this->root(), rs, this->cost);
147         RRTNode *pn = nn;
148         std::vector<RRTNode *> nvs;
149         bool connected;
150         RRTNode *op; // old parent
151         float od; // old direct cost
152         float oc; // old cumulative cost
153         for (auto ns: this->steer(nn, rs)) {
154                 nvs = this->nv(this->root(), ns, this->cost, MIN(
155                                         GAMMA_RRTSTAR(this->nodes().size()),
156                                         1)); // TODO const
157                 this->nodes().push_back(ns);
158                 connected = false;
159                 pn->add_child(ns, this->cost(pn, ns));
160                 if (this->collide(pn, ns)) {
161                         pn->children().pop_back();
162                 } else {
163                         connected = true;
164                 }
165                 // connect
166                 for (auto nv: nvs) {
167                         if (!connected || (nv->ccost() + this->cost(nv, ns) <
168                                         ns->ccost())) {
169                                 op = ns->parent();
170                                 od = ns->dcost();
171                                 oc = ns->ccost();
172                                 nv->add_child(ns, this->cost(nv, ns));
173                                 if (this->collide(nv, ns)) {
174                                         nv->children().pop_back();
175                                         ns->parent(op);
176                                         ns->dcost(od);
177                                         ns->ccost(oc);
178                                 } else if (connected) {
179                                         op->children().pop_back();
180                                 } else {
181                                         connected = true;
182                                 }
183                         }
184                 }
185                 if (!connected)
186                         return false;
187                 // rewire
188                 for (auto nv: nvs) {
189                         if (ns->ccost() + this->cost(ns, nv) < nv->ccost()) {
190                                 op = nv->parent();
191                                 od = nv->dcost();
192                                 oc = nv->ccost();
193                                 ns->add_child(nv, this->cost(ns, nv));
194                                 if (this->collide(ns, nv)) {
195                                         ns->children().pop_back();
196                                         nv->parent(op);
197                                         nv->dcost(od);
198                                         nv->ccost(oc);
199                                 } else {
200                                         op->rem_child(nv);
201                                 }
202                         }
203                 }
204                 pn = ns;
205                 if (this->goal_found(pn, this->cost)) {
206                         this->tlog(this->findt());
207                         break;
208                 }
209         }
210         return this->goal_found();
211 }
212
213 T1::T1(RRTNode *init, RRTNode *goal):
214         RRTBase(init, goal),
215         nn(nn1),
216         nv(nv1),
217         sample(sa1),
218         steer(st1),
219         cost(co1)
220 {
221         srand(static_cast<unsigned>(time(0)));
222 }
223
224 bool T1::next()
225 {
226         RRTNode *rs;
227         if (this->samples().size() == 0)
228                 rs = this->goal();
229         else
230                 rs = this->sample();
231         this->samples().push_back(rs);
232         RRTNode *nn = this->nn(this->root(), rs, this->cost);
233         RRTNode *pn = nn;
234         std::vector<RRTNode *> nvs;
235         bool connected;
236         RRTNode *op; // old parent
237         float od; // old direct cost
238         float oc; // old cumulative cost
239         std::vector<RRTNode *> steered = this->steer(nn, rs);
240         // RRT* for first node
241         RRTNode *ns = steered[0];
242         {
243                 nvs = this->nv(this->root(), ns, this->cost, MIN(
244                                         GAMMA_RRTSTAR(this->nodes().size()),
245                                         0.2)); // TODO const
246                 this->nodes().push_back(ns);
247                 connected = false;
248                 pn->add_child(ns, this->cost(pn, ns));
249                 if (this->collide(pn, ns)) {
250                         pn->children().pop_back();
251                 } else {
252                         connected = true;
253                 }
254                 // connect
255                 for (auto nv: nvs) {
256                         if (!connected || (nv->ccost() + this->cost(nv, ns) <
257                                         ns->ccost())) {
258                                 op = ns->parent();
259                                 od = ns->dcost();
260                                 oc = ns->ccost();
261                                 nv->add_child(ns, this->cost(nv, ns));
262                                 if (this->collide(nv, ns)) {
263                                         nv->children().pop_back();
264                                         ns->parent(op);
265                                         ns->dcost(od);
266                                         ns->ccost(oc);
267                                 } else if (connected) {
268                                         op->children().pop_back();
269                                 } else {
270                                         connected = true;
271                                 }
272                         }
273                 }
274                 if (!connected)
275                         return false;
276                 // rewire
277                 for (auto nv: nvs) {
278                         if (ns->ccost() + this->cost(ns, nv) < nv->ccost()) {
279                                 op = nv->parent();
280                                 od = nv->dcost();
281                                 oc = nv->ccost();
282                                 ns->add_child(nv, this->cost(ns, nv));
283                                 if (this->collide(ns, nv)) {
284                                         ns->children().pop_back();
285                                         nv->parent(op);
286                                         nv->dcost(od);
287                                         nv->ccost(oc);
288                                 } else {
289                                         op->rem_child(nv);
290                                 }
291                         }
292                 }
293                 pn = ns;
294                 if (this->goal_found(pn, this->cost)) {
295                         this->tlog(this->findt());
296                 }
297         }
298         unsigned int i = 0;
299         for (i = 1; i < steered.size(); i++) {
300                 ns = steered[i];
301                 this->nodes().push_back(ns);
302                 pn->add_child(ns, this->cost(pn, ns));
303                 if (this->collide(pn, ns)) {
304                         pn->children().pop_back();
305                         break;
306                 }
307                 pn = ns;
308                 if (this->goal_found(pn, this->cost)) {
309                         this->tlog(this->findt());
310                         break;
311                 }
312         }
313         return this->goal_found();
314 }