rtime.felk.cvut.cz Git - hubacji1/rrts.git/blob - src/rrts.cc
Add method for setting sampling info
[hubacji1/rrts.git] / src / rrts.cc
#include "rrts.h"

#include <cmath>
2
// Step length used by steer(), upper bound of the neighbourhood radius in
// nv(), and the goal-reached distance threshold in next().
static constexpr double ETA = 1.0;

// Radius term used by nv(): (log(cV) / cV)^(1/3), where the caller passes
// the current number of tree nodes as cV.
//
// This is a plain function rather than a macro built on the GNU
// statement-expression and __typeof__ extensions, so the file only needs
// standard C++. The upper-case name is kept so that existing GAMMA(...)
// call sites compile unchanged.
//
// cV is taken as double; an integral count is converted at the call.
// For cV == 1 the result is exactly 0.
static inline double GAMMA(double cV)
{
        return std::pow(std::log(cV) / cV, 1.0 / 3.0);
}
8
9 RRTNode::RRTNode()
10 {
11 }
12
13 // RRT procedures
14 bool RRTS::collide()
15 {
16         bool collide = false;
17         return collide;
18 }
19
20 double RRTS::cost(RRTNode &f, RRTNode &t)
21 {
22         double cost = 0;
23         cost = sqrt(pow(t.y() - f.y(), 2) + pow(t.x() - f.x(), 2));
24         return cost;
25 }
26
27 void RRTS::sample()
28 {
29         double x = this->ndx_(this->gen_);
30         double y = this->ndy_(this->gen_);
31         double h = this->ndh_(this->gen_);
32         this->samples().push_back(RRTNode());
33         this->samples().back().x(x);
34         this->samples().back().y(y);
35         this->samples().back().h(h);
36 }
37
38 RRTNode *RRTS::nn(RRTNode &t)
39 {
40         RRTNode *nn = &this->nodes().front();
41         double cost = this->cost(*nn, t);
42         for (auto &f: this->nodes()) {
43                 if (this->cost(f, t) < cost) {
44                         nn = &f;
45                         cost = this->cost(f, t);
46                 }
47         }
48         return nn;
49 }
50
51 std::vector<RRTNode *> RRTS::nv(RRTNode &t)
52 {
53         std::vector<RRTNode *> nv;
54         double cost = std::min(GAMMA(this->nodes().size()), ETA);
55         for (auto &f: this->nodes())
56                 if (this->cost(f, t) < cost)
57                         nv.push_back(&f);
58         return nv;
59 }
60
61 void RRTS::steer(RRTNode &f, RRTNode &t)
62 {
63         double angl = atan2(t.y() - f.y(), t.x() - f.x());
64         this->steered().clear();
65         this->steered().push_back(RRTNode());
66         this->steered().back().x(f.x() + ETA * cos(angl));
67         this->steered().back().y(f.y() + ETA * sin(angl));
68         this->steered().back().h(angl);
69 }
70
71 // RRT* procedures
72 bool RRTS::connect()
73 {
74         bool conn = false;
75         RRTNode *t = &this->steered().front();
76         RRTNode *f = this->nn(this->samples().back());
77         double cost = this->cost(*f, *t);
78         for (auto n: this->nv(*t)) {
79                 if (this->cost(*n, *t) < cost) {
80                         f = n;
81                         cost = this->cost(*n, *t);
82                 }
83         }
84         this->nodes().push_back(this->steered().front());
85         this->steered().erase(this->steered().begin());
86         t = &this->nodes().back();
87         t->p(f);
88         t->c(f->c() + this->cost(*f, *t));
89         conn = true;
90         return conn;
91 }
92
93 void RRTS::rewire()
94 {
95         RRTNode *f = &this->nodes().back();
96         for (auto n: this->nv(*f)) {
97                 if (f->c() + this->cost(*f, *n) < n->c())
98                         n->p(f);
99         }
100 }
101
102 // API
103 std::vector<RRTNode *> RRTS::path()
104 {
105         std::vector<RRTNode *> path;
106         return path;
107 }
108
109 bool RRTS::next()
110 {
111         bool next = true;
112         this->icnt_++;
113         this->sample();
114         this->steer(
115                 *this->nn(this->samples().back()),
116                 this->samples().back()
117         );
118         this->connect();
119         this->rewire();
120         for (auto &n: this->goals()) {
121                 double cost = this->cost(this->nodes().back(), n);
122                 if (cost < ETA) {
123                         next = false;
124                         if (
125                                 n.p() == nullptr
126                                 || this->nodes().back().c() + cost < n.c()
127                         ) {
128                                 n.p(&this->nodes().back());
129                                 n.c(this->nodes().back().c() + cost);
130                         }
131                 }
132         }
133         if (this->icnt_ > 999)
134                 next = false;
135         return next;
136 }
137
138 void RRTS::set_sample(
139         double mx, double dx,
140         double my, double dy,
141         double mh, double dh
142 )
143 {
144         this->ndx_ = std::normal_distribution<double>(mx, dx);
145         this->ndy_ = std::normal_distribution<double>(my, dy);
146         this->ndh_ = std::normal_distribution<double>(mh, dh);
147 }
148
149 RRTS::RRTS()
150         : gen_(std::random_device{}())
151 {
152         this->nodes().push_back(RRTNode()); // root
153 }