]> rtime.felk.cvut.cz Git - hubacji1/iamcar2.git/blob - scripts/print.py
Add multiple costs plot function
[hubacji1/iamcar2.git] / scripts / print.py
1 """Procedures for printing scenario results."""
2 import sys
3 import matplotlib.pyplot as plt
4 import numpy as np
5 import scipy.stats as ss
6
7 import scenario
8
# Output switch read by infoprint(): when true, the statistics are emitted as
# a LaTeX table (table*/tabular with "&" separators); otherwise as plain text.
LATEX = False
10
def mean_conf_int(data, conf=0.95):
    """Return (mean, lower, upper) of data.

    The bounds are the mean minus/plus the standard error of the mean
    scaled by the Student's t quantile for ``len(data) - 1`` degrees of
    freedom.

    see https://stackoverflow.com/questions/15033511/compute-a-confidence-interval-from-sample-data

    Keyword arguments:
    data -- A list of data.
    conf -- Confidence level of the interval (0.95 by default).
    """
    samples = np.array(data)
    dof = len(samples) - 1
    center = np.mean(samples)
    # Two-sided interval: the quantile is taken at (1 + conf) / 2.
    half_width = ss.sem(samples) * ss.t.ppf((1 + conf) / 2, dof)
    lower = center - half_width
    upper = center + half_width
    return (center, lower, upper)
26
def _print_row(label, cells, cell_fmt="{:<8.2f}", label_width=12):
    """Print one table row: a padded label, then one formatted cell per item.

    Keyword arguments:
    label -- Row caption printed in the first column.
    cells -- Iterable of values; consumed lazily, one cell printed at a time.
    cell_fmt -- Format string applied to every cell.
    label_width -- Width the label is left-justified to.
    """
    # In LaTeX mode columns are separated by "&" and rows end with "\\".
    separator = " &" if LATEX else ""
    print(label.ljust(label_width), end="")
    for cell in cells:
        print("{} {}".format(separator, cell_fmt.format(cell)), end="")
    print("\\\\" if LATEX else "")


def infoprint(w=None, t=""):
    """Print statistic information about scenario results.

    For every entry of `w` one column is printed, with rows for the mean,
    the bounds of the 0.95 confidence interval of the mean, the median,
    the 95th percentile, the minimum and the maximum. When the module-level
    LATEX flag is true, the output is a LaTeX table with `t` as its caption;
    otherwise `t` is printed as a heading above a plain-text table.

    Keyword arguments:
    w -- What to print: mapping of scenario name to a sequence of numbers.
         Omitting it (or passing None) prints the row labels only.
    t -- Print title.
    """
    # None instead of a shared mutable {} default; behaves like an empty dict.
    if w is None:
        w = {}
    series = list(w.values())

    if LATEX:
        print("\\begin{table*}[h]")
        # One leading column for the row labels plus one per scenario.
        print("\\begin{tabular}{|c|" + "c|" * len(w) + "}")
        # Escaped backslash: "\h" is an invalid escape sequence in Python.
        print("\\hline")
    else:
        print(t)

    # Header row; the label column is two characters wider in LaTeX mode.
    _print_row(
        "Scenarios:",
        w,
        cell_fmt="{:<8}",
        label_width=14 if LATEX else 12,
    )
    if LATEX:
        print("\\hline")

    _print_row("Mean:", map(np.mean, series))
    # Compute each confidence interval once and reuse it for both bounds.
    intervals = [mean_conf_int(values) for values in series]
    _print_row("0.95 low:", (ci[1] for ci in intervals))
    _print_row("0.95 high:", (ci[2] for ci in intervals))
    _print_row("Median:", map(np.median, series))
    _print_row(
        "0.95 perc.:",
        (np.percentile(values, [95])[0] for values in series),
    )
    _print_row("Minimum:", map(np.min, series))
    _print_row("Maximum:", map(np.max, series))

    if LATEX:
        print("\\hline")
        print("\\end{tabular}")
        print("\\caption{" + "{}".format(t) + "}")
        print("\\end{table*}")
119
def error_infoprint(w={}, t=""):
    """Print error information about scenario results.

    The title is printed first, followed by a row of scenario names and a
    row with the matching rates formatted to two decimal places.

    Keyword arguments:
    w -- What to print; only the element at index 0 is used and it has
         to be a mapping of scenario name to rate.
    t -- Print title.
    """
    print(t)
    rates = w[0]
    pad_label = "{:<12}".format

    print(pad_label("Scenarios:"), end="")
    for name in rates.keys():
        print(" {:<8}".format(name), end="")
    print()

    print(pad_label("Rate:"), end="")
    for rate in rates.values():
        print(" {:<8.2f}".format(rate), end="")
    print()
137
if __name__ == "__main__":
    # Usage: print.py [WHAT [DNAME]]
    # WHAT selects the report (defaults to "time"); DNAME, when given,
    # overrides scenario.DNAME before any results are requested.
    cli_args = sys.argv[1:]
    w = cli_args[0] if cli_args else "time"
    if len(cli_args) > 1:
        scenario.DNAME = cli_args[1]

    for rc_key, rc_value in (
        ("font.size", 22),
        ("font.family", "sans-serif"),
        ("figure.figsize", [12, 4]),
    ):
        plt.rcParams[rc_key] = rc_value

    # Report name -> (name of the scenario function supplying the data,
    # printer, title). The scenario function is looked up only for the
    # selected report.
    reports = {
        "time": ("time", infoprint, "Elapsed time"),
        "cost": ("cost", infoprint, "Final path cost"),
        "orig_cost": ("orig_cost", infoprint, "Original path cost"),
        "cusp": ("cusp", infoprint, "Changes in direction"),
        "orig_cusp": ("orig_cusp", infoprint, "Changes in direction"),
        "error": ("error_rate", error_infoprint, "Error rate"),
        "iter": ("iter", infoprint, "Number of iterations"),
    }

    if w in reports:
        source_name, printer, title = reports[w]
        printer(getattr(scenario, source_name)(), title)
    else:
        print("""The following arguments are allowed:

        time, cost, orig_cost, cusp, orig_cusp, error, iter
        """)