]> rtime.felk.cvut.cz Git - hubacji1/iamcar.git/blob - plot.py
43f75259fa1c0441ae60edb174be2f4a09a5ddb6
[hubacji1/iamcar.git] / plot.py
1 # -*- coding: utf-8 -*-
2 """This script loads a scenario and the resulting trajectories from files and plots them."""
3 from json import loads
4 from math import copysign, cos, pi, sin
5 from matplotlib import pyplot as plt
6 from sys import argv
7
def sign(x):
    """Return 1.0 carrying the sign of `x` (note ``sign(-0.0) == -1.0``)."""
    return copysign(1, x)


# Vehicle dimensions used by `car_frame` (presumably metres, matching the
# "[m]" axis labels -- TODO confirm).
HEIGHT = 1.450
LENGTH = 3.760
SAFETY_DIST = 0
WHEEL_BASE = 2.450
WIDTH = 1.625

# Command line: ``plot.py SCENARIO_FILE TRAJECTORY_FILE``.  A missing
# argument becomes None so that `load_scenario` / `load_trajectory` raise
# their own "File name as argument needed" ValueError instead of the module
# failing with an IndexError as soon as it is imported.
SCEN_FILE = argv[1] if len(argv) > 1 else None
TRAJ_FILE = argv[2] if len(argv) > 2 else None

# Switches for the optional plot sections.
PLOT = {
        "enable" : True,   # not consulted by the plotting code in this file
        "node" : False,    # planner nodes stored under "node"
        "edge" : True,     # planner edges stored under "edge"
        "sample" : False,  # samples stored under "samp"
        "frame" : False,   # draw samples/trajectory poses as car outlines
        "path" : False,    # paths stored under "path"
        "traj" : True,     # result trajectories stored under "traj"
        }
# Matplotlib colour names per plotted element; "log" is the per-trajectory
# palette, indexed by trajectory number.
COLOR = {
        "node": "lightgrey",
        "edge": "lightgrey",
        "start": "violet",
        "goal": "red",
        "sample": "magenta",
        "path": "grey",
        "trajectory": "blue",
        "trajectory-frame": "lightblue",
        "obstacle": "black",
        "log": ("gold", "orange", "blueviolet", "blue", "navy", "black"),
        }
39
def _frame_corner(pose, lateral_angle, half_width, longitudinal, margin):
    """Return ``(x, y)`` of one corner of the car frame.

    Keyword arguments:
    pose -- The pose of a car as ``(x, y, heading)``, heading in radians.
    lateral_angle -- Direction of the sideways offset (heading +/- pi/2).
    half_width -- Sideways distance from the pose to the corner.
    longitudinal -- Signed distance along the heading (positive = forward).
    margin -- Signed safety margin added along the heading.
    """
    heading = pose[2]
    # Summed left to right so the result is bit-identical to the former
    # step-by-step ``+=`` computation.
    x = (pose[0] + half_width * cos(lateral_angle)
         + longitudinal * cos(heading) + margin * cos(heading))
    y = (pose[1] + half_width * sin(lateral_angle)
         + longitudinal * sin(heading) + margin * sin(heading))
    return x, y


def car_frame(pose, length=None, width=None, wheel_base=None,
              safety_dist=None):
    """Return `xcoords`, `ycoords` tuples of the four car frame corners.

    The corners are ordered left front, left rear, right rear, right front.
    The safety distance only extends the frame forwards and backwards, not
    sideways.

    Keyword arguments:
    pose -- The pose of a car as ``(x, y, heading)``, heading in radians.
            Presumably the rear-axle position: the frame reaches
            ``(length - wheel_base) / 2`` behind it (assumes equal front and
            rear overhangs -- TODO confirm).
    length -- Car length; defaults to the module constant LENGTH.
    width -- Car width; defaults to WIDTH.
    wheel_base -- Wheel base; defaults to WHEEL_BASE.
    safety_dist -- Longitudinal safety margin; defaults to SAFETY_DIST.
    """
    # Resolve defaults at call time so later changes of the module
    # constants are still honoured.
    length = LENGTH if length is None else length
    width = WIDTH if width is None else width
    wheel_base = WHEEL_BASE if wheel_base is None else wheel_base
    safety_dist = SAFETY_DIST if safety_dist is None else safety_dist

    dr = (length - wheel_base) / 2  # reach behind the pose
    df = length - dr                # reach in front of the pose
    half_width = width / 2.0
    left = pose[2] + pi / 2.0
    right = pose[2] - pi / 2.0

    corners = (
        _frame_corner(pose, left, half_width, df, safety_dist),     # LF
        _frame_corner(pose, left, half_width, -dr, -safety_dist),   # LR
        _frame_corner(pose, right, half_width, -dr, -safety_dist),  # RR
        _frame_corner(pose, right, half_width, df, safety_dist),    # RF
    )
    xcoords, ycoords = zip(*corners)
    return (xcoords, ycoords)
92
def load_scenario(fname):
    """Read the JSON scenario file `fname` and return the decoded object.

    Keyword arguments:
    fname -- Path to the scenario file.

    Raises ValueError when `fname` is None.
    """
    if fname is None:
        raise ValueError("File name as argument needed")
    with open(fname, "r") as scenario_file:
        content = scenario_file.read()
    return loads(content)
100
def load_trajectory(fname):
    """Decode the JSON trajectory file `fname` and return its contents.

    Keyword arguments:
    fname -- Path to the trajectory file; None raises ValueError.
    """
    if fname is not None:
        with open(fname, "r") as traj_file:
            return loads(traj_file.read())
    raise ValueError("File name as argument needed")
108
def plot_nodes(nodes=None):
    """Return `xcoords`, `ycoords` lists of nodes to plot.

    Keyword arguments:
    nodes -- Iterable of nodes. Each node is indexable; its first two items
             are taken as x and y, and any further items (e.g. a heading
             angle) are ignored. None (the default) means no nodes.
    """
    xcoords = []
    ycoords = []
    if nodes is None:
        return (xcoords, ycoords)
    # A single pass keeps this working for one-shot iterators as well.
    for node in nodes:
        xcoords.append(node[0])
        ycoords.append(node[1])
    return (xcoords, ycoords)
121
def plot_segments(segments=None):
    """Placeholder for returning xcoords, ycoords of segments to plot.

    NOT IMPLEMENTED: the argument is ignored and None is returned, so
    callers must not unpack the result yet.

    Keyword arguments:
    segments -- The list of segments to plot (currently unused).
    """
    # TODO: implement once the expected segment format is settled.
    return None
129
# Errors that signal missing or malformed scenario/trajectory data. Each
# optional plot section catches only these, so bad input is reported and
# skipped while e.g. KeyboardInterrupt still stops the script.
_DATA_ERRORS = (KeyError, IndexError, TypeError, ValueError)


def _plot_obstacles(ax, scenario):
    """Draw the obstacles listed under scenario["obst"] on `ax`.

    An obstacle may have a "segment" (a list of points drawn as a polyline)
    and/or a "circle" given as ``[x, y, radius]``; a missing or malformed
    part is silently skipped.
    """
    for obstacle in scenario["obst"]:
        try:
            ax.plot(*plot_nodes(obstacle["segment"]),
                    color=COLOR["obstacle"])
        except _DATA_ERRORS:
            pass  # no usable segment outline on this obstacle
        try:
            circle = obstacle["circle"]
            ax.add_artist(plt.Circle((circle[0], circle[1]), circle[2],
                                     color=COLOR["obstacle"], fill=False))
        except _DATA_ERRORS:
            pass  # no usable circle on this obstacle


def _plot_trajectories(ax, traj_data, max_traj=None):
    """Draw the result trajectories stored under traj_data["traj"] on `ax`.

    Trajectory ``i`` is labelled with traj_data["cost"][i] and coloured with
    COLOR["log"][i] (black once that palette runs out). With PLOT["frame"]
    every pose is drawn as a car outline; otherwise the poses are joined by
    a line. Prints "No trajectory" when nothing can be drawn.

    Keyword arguments:
    ax -- The matplotlib axes to draw on.
    traj_data -- The loaded trajectory data.
    max_traj -- Draw at most this many trajectories; None draws all.
    """
    try:
        trajectories = traj_data["traj"]
        costs = traj_data["cost"]
        if max_traj is not None:
            trajectories = trajectories[:max_traj]
        if not trajectories:
            print("No trajectory")
            return
        for i, traj in enumerate(trajectories):
            if i < len(COLOR["log"]):
                color = COLOR["log"][i]
            else:
                color = "black"
            if PLOT["frame"]:
                for pose in traj:
                    ax.plot(*car_frame(pose), color=color, label=costs[i])
            else:
                ax.plot(*plot_nodes(traj), color=color, label=costs[i])
    except _DATA_ERRORS:
        print("No trajectory")


def _plot_panel(ax, scenario, traj_data, title, show_edges, show_traj,
                max_traj=None):
    """Draw one subplot: obstacles, optional planner data, start and goal.

    Keyword arguments:
    ax -- The matplotlib axes to draw on.
    scenario -- The loaded scenario ("obst", "init" and "goal" are used).
    traj_data -- The loaded trajectory data; a section whose data is
                 missing prints a "No ..." message and is skipped.
    title -- The subplot title.
    show_edges -- Draw traj_data["edge"]: a list of edge lists, each edge a
                  pair of points.
    show_traj -- Draw the result trajectories.
    max_traj -- Trajectory limit passed to `_plot_trajectories`.
    """
    ax.set_aspect("equal")
    ax.set_title(title)
    ax.set_xlabel("Longitudinal direction [m]")
    ax.set_ylabel("Lateral direction [m]")

    _plot_obstacles(ax, scenario)
    if PLOT["node"]:
        try:
            ax.plot(*plot_nodes(traj_data["node"]), color=COLOR["node"],
                    marker=".", linestyle="None")
        except _DATA_ERRORS:
            print("No RRTNode")
    if show_edges:
        try:
            for edges in traj_data["edge"]:
                for e in edges:
                    ax.plot([e[0][0], e[1][0]], [e[0][1], e[1][1]],
                            color=COLOR["edge"])
        except _DATA_ERRORS:
            print("No edges")
    if PLOT["sample"]:
        try:
            if PLOT["frame"]:
                for sample in traj_data["samp"]:
                    ax.plot(*car_frame(sample), color=COLOR["sample"])
            else:
                ax.plot(*plot_nodes(traj_data["samp"]),
                        color=COLOR["sample"], marker=".", linestyle="None")
        except _DATA_ERRORS:
            print("No RRTSample")
    if PLOT["path"]:
        try:
            for path in traj_data["path"]:
                ax.plot(*plot_nodes(path), color=COLOR["path"])
        except _DATA_ERRORS:
            print("No path")
    if show_traj:
        _plot_trajectories(ax, traj_data, max_traj)
    ax.plot(*plot_nodes([scenario["init"]]), color=COLOR["start"], marker=".")
    ax.plot(*plot_nodes([scenario["goal"]]), color=COLOR["goal"], marker=".")


if __name__ == "__main__":
    s = load_scenario(SCEN_FILE)
    try:
        t = load_trajectory(TRAJ_FILE)
    except (IOError, OSError, ValueError) as err:
        # The scenario alone is still worth drawing: continue with empty
        # planner data so every optional section reports "No ...".
        print("No trajectory loaded: {}".format(err))
        t = {}

    plt.rcParams["font.size"] = 24
    fig = plt.figure()
    # Left: the result trajectories (at most four), without edges.
    _plot_panel(fig.add_subplot(121), s, t, "RRT* final path",
                show_edges=False, show_traj=PLOT["traj"], max_traj=4)
    # Right: all stored edges, without trajectories.
    _plot_panel(fig.add_subplot(122), s, t, "RRT* all edges",
                show_edges=PLOT["edge"], show_traj=False)
    # Trajectory lines carry their cost as label, so calling
    # ax.legend(title="Cost") on the left axes would list them.
    plt.show()
    plt.close(fig)