rtime.felk.cvut.cz Git - hubacji1/iamcar.git/blob - plot.py
Plot sample's frame if defined
[hubacji1/iamcar.git] / plot.py
# -*- coding: utf-8 -*-
"""Load a scenario and a result trajectory from JSON files and plot them.

Usage::

    python plot.py SCENARIO_FILE TRAJECTORY_FILE

If the trajectory file cannot be read or parsed, only the scenario
(obstacles, start and goal) is drawn.  Which planner layers are drawn is
controlled by the PLOT table below.
"""
3 from json import loads
4 from math import cos, pi, sin
5 from matplotlib import pyplot as plt
6 from sys import argv
7
# Vehicle geometry.  Units are presumably metres, matching the "[m]" axis
# labels of the plot.  car_frame() uses WIDTH, WHEEL_BASE and SAFETY_DIST;
# HEIGHT and LENGTH are not referenced by this script.
HEIGHT = 1.418
LENGTH = 4.970
SAFETY_DIST = 0
WHEEL_BASE = 2.9
WIDTH = 1.931

# Input files: argv[1] is the scenario, argv[2] the planner result.  A
# missing argument becomes None so that load_scenario()/load_trajectory()
# raise their "File name as argument needed" ValueError, instead of this
# module failing with an IndexError before that check can ever run.
SCEN_FILE = argv[1] if len(argv) > 1 else None
TRAJ_FILE = argv[2] if len(argv) > 2 else None

# Layers to draw.  "frame" switches samples and trajectories from points /
# polylines to car outlines (car_frame()).  "enable" is not consulted by
# this script.
PLOT = {
    "enable": True,
    "node": False,    # planner tree nodes, trajectory["node"]
    "edge": True,     # planner tree edges, trajectory["edge"]
    "sample": False,  # samples, trajectory["samp"]
    "frame": False,
    "path": False,    # trajectory["path"]
    "traj": True,     # final trajectories, trajectory["traj"]
}

# Plot colours.  COLOR["log"] is indexed by trajectory number;
# "trajectory" and "trajectory-frame" are not referenced by this script.
COLOR = {
    "node": "lightgrey",
    "edge": "lightgrey",
    "start": "violet",
    "goal": "red",
    "sample": "magenta",
    "path": "lightgrey",
    "trajectory": "blue",
    "trajectory-frame": "lightblue",
    "obstacle": "black",
    "log": ("gold", "orange", "blueviolet", "blue", "navy", "black"),
}
37
def car_frame(pose):
    """Return `xcoords`, `ycoords` of the car outline at `pose`.

    The outline runs from the pose point (pushed back by SAFETY_DIST) to
    WHEEL_BASE ahead of it (pushed forward by SAFETY_DIST) and is WIDTH
    wide, centred on the heading line.  Corners are ordered left-front,
    left-rear, right-rear, right-front.  The polygon is not closed, so
    plotting the result draws three sides; the front edge between
    right-front and left-front is left open.

    Keyword arguments:
    pose -- The pose of a car; pose[0], pose[1] are x, y and pose[2] is
            the heading in radians (it is fed to cos/sin).

    Returns:
    (xcoords, ycoords) -- two 4-tuples of corner coordinates.
    """
    x, y, heading = pose[0], pose[1], pose[2]
    half_width = WIDTH / 2.0
    left = heading + pi / 2.0
    right = heading - pi / 2.0
    ahead_x = cos(heading)
    ahead_y = sin(heading)

    # Lateral offsets of the left and right sides.
    left_x = half_width * cos(left)
    left_y = half_width * sin(left)
    right_x = half_width * cos(right)
    right_y = half_width * sin(right)

    xcoords = (
        x + left_x + WHEEL_BASE * ahead_x + SAFETY_DIST * ahead_x,
        x + left_x - SAFETY_DIST * ahead_x,
        x + right_x - SAFETY_DIST * ahead_x,
        x + right_x + WHEEL_BASE * ahead_x + SAFETY_DIST * ahead_x,
    )
    ycoords = (
        y + left_y + WHEEL_BASE * ahead_y + SAFETY_DIST * ahead_y,
        y + left_y - SAFETY_DIST * ahead_y,
        y + right_y - SAFETY_DIST * ahead_y,
        y + right_y + WHEEL_BASE * ahead_y + SAFETY_DIST * ahead_y,
    )
    return (xcoords, ycoords)
83
def load_scenario(fname):
    """Load a scenario from a JSON file.

    Keyword arguments:
    fname -- Path of the scenario file.

    Returns the decoded JSON document.  This script indexes it with the
    keys "init", "goal" and "obst".

    Raises ValueError when `fname` is None (no command-line argument),
    OSError when the file cannot be opened, and json.JSONDecodeError (a
    ValueError subclass) when the content is not valid JSON.
    """
    if fname is None:
        raise ValueError("File name as argument needed")
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default.
    with open(fname, "r", encoding="utf-8") as f:
        scenario = loads(f.read())
    return scenario
91
def load_trajectory(fname):
    """Load a planner result (trajectory) from a JSON file.

    Keyword arguments:
    fname -- Path of the trajectory file.

    Returns the decoded JSON document.  This script indexes it with the
    keys "node", "edge", "samp", "path", "traj" and "cost".

    Raises ValueError when `fname` is None (no command-line argument),
    OSError when the file cannot be opened, and json.JSONDecodeError (a
    ValueError subclass) when the content is not valid JSON.
    """
    if fname is None:
        raise ValueError("File name as argument needed")
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default.
    with open(fname, "r", encoding="utf-8") as f:
        trajectory = loads(f.read())
    return trajectory
99
def plot_nodes(nodes=None):
    """Return xcoords, ycoords of nodes to plot.

    Keyword arguments:
    nodes -- Iterable of nodes; the first two items of each node are its
             x and y coordinates (further items, e.g. a heading, are
             ignored).  None or omitted means no nodes.

    Returns:
    (xcoords, ycoords) -- two lists, suitable for ``plt.plot(*...)``.
    """
    # None instead of a shared mutable [] default.
    if nodes is None:
        nodes = ()
    xcoords = []
    ycoords = []
    # Single pass so one-shot iterators (e.g. generators) work too.
    for node in nodes:
        xcoords.append(node[0])
        ycoords.append(node[1])
    return (xcoords, ycoords)
112
def plot_segments(segments=None):
    """Return xcoords, ycoords of segments to plot.

    Keyword arguments:
    segments -- Iterable of segments; each segment is expected to be a
                pair of points ``(p, q)`` whose first two items are x and
                y (the same layout as the tree edges drawn in
                ``__main__``).  None or omitted means no segments.

    Returns:
    (xcoords, ycoords) -- two lists.  Consecutive segments are separated by
    a NaN entry, which matplotlib treats as a gap, so a single
    ``plt.plot(*plot_segments(s))`` call draws them as disjoint lines.
    """
    xcoords = []
    ycoords = []
    if segments is None:
        return (xcoords, ycoords)
    gap = float("nan")
    for segment in segments:
        if xcoords:
            # Break the line between this segment and the previous one.
            xcoords.append(gap)
            ycoords.append(gap)
        start, end = segment[0], segment[1]
        xcoords.extend((start[0], end[0]))
        ycoords.extend((start[1], end[1]))
    return (xcoords, ycoords)
120
# Errors that mean "this layer's data is missing or malformed": the layer is
# skipped with a message instead of aborting the whole plot.
_DATA_ERRORS = (KeyError, IndexError, TypeError)


def _log_color(index):
    """Return the colour of the `index`-th trajectory.

    Colours come from COLOR["log"]; trajectories beyond the end of that
    palette are drawn in black.
    """
    palette = COLOR["log"]
    return palette[index] if index < len(palette) else "black"


def _plot_obstacles(ax, scenario):
    """Draw the obstacles listed under scenario["obst"] onto `ax`.

    An obstacle may have a "segment" entry (a polyline of points) and/or a
    "circle" entry ([x, y, radius]); entries that are absent are skipped.
    A scenario without an "obst" key draws nothing.
    """
    for obstacle in scenario.get("obst", []):
        if "segment" in obstacle:
            ax.plot(*plot_nodes(obstacle["segment"]),
                    color=COLOR["obstacle"])
        if "circle" in obstacle:
            circle = obstacle["circle"]
            ax.add_artist(plt.Circle((circle[0], circle[1]), circle[2],
                    color=COLOR["obstacle"], fill=False))


def _plot_tree(ax, trajectory):
    """Draw the planner debug layers (nodes, edges, samples, paths).

    Each layer is drawn only when enabled in PLOT; a layer whose data is
    missing or malformed is skipped with a message.
    """
    if PLOT["node"]:
        try:
            ax.plot(*plot_nodes(trajectory["node"]), color=COLOR["node"],
                    marker=".", linestyle="None")
        except _DATA_ERRORS:
            print("No RRTNode")
    if PLOT["edge"]:
        try:
            for edges in trajectory["edge"]:
                for e in edges:
                    ax.plot([e[0][0], e[1][0]], [e[0][1], e[1][1]],
                            color=COLOR["edge"])
        except _DATA_ERRORS:
            print("No edges")
    if PLOT["sample"]:
        try:
            if PLOT["frame"]:
                for sample in trajectory["samp"]:
                    ax.plot(*car_frame(sample), color=COLOR["sample"])
            else:
                ax.plot(*plot_nodes(trajectory["samp"]),
                        color=COLOR["sample"], marker=".", linestyle="None")
        except _DATA_ERRORS:
            print("No RRTSample")
    if PLOT["path"]:
        try:
            for path in trajectory["path"]:
                ax.plot(*plot_nodes(path), color=COLOR["path"])
        except _DATA_ERRORS:
            print("No path")


def _plot_trajectories(ax, trajectory):
    """Draw the final trajectories; each gets one legend entry (its cost)."""
    try:
        for index, poses in enumerate(trajectory["traj"]):
            color = _log_color(index)
            cost = trajectory["cost"][index]
            if PLOT["frame"]:
                for k, pose in enumerate(poses):
                    # Label only the first outline so the legend shows one
                    # entry per trajectory rather than one per pose.
                    extra = {"label": cost} if k == 0 else {}
                    ax.plot(*car_frame(pose), color=color, **extra)
            else:
                ax.plot(*plot_nodes(poses), color=color, label=cost)
    except _DATA_ERRORS:
        print("No trajectory")


if __name__ == "__main__":
    s = load_scenario(SCEN_FILE)
    try:
        t = load_trajectory(TRAJ_FILE)
    except (OSError, ValueError) as err:
        # The trajectory is optional: without it only the scenario is drawn.
        print("No trajectory loaded ({})".format(err))
        t = {}

    plt.rcParams["font.size"] = 24
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_aspect("equal")
    ax.set_title("SCENARIO")
    ax.set_xlabel("Longitudinal direction [m]")
    ax.set_ylabel("Lateral direction [m]")

    _plot_obstacles(ax, s)
    _plot_tree(ax, t)
    if PLOT["traj"]:
        _plot_trajectories(ax, t)
    ax.plot(*plot_nodes([s["init"]]), color=COLOR["start"], marker=".")
    ax.plot(*plot_nodes([s["goal"]]), color=COLOR["goal"], marker=".")

    handles, labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend(handles, labels, loc="upper center",
                bbox_to_anchor=(0.5, -0.11), title="Cost")
    # To write an image instead of opening a window, replace plt.show() with
    # plt.savefig(path, bbox_inches="tight"); pass the legend through
    # bbox_extra_artists so it is not cropped.
    plt.show()
    plt.close(fig)