]> rtime.felk.cvut.cz Git - boost-statechart-viewer.git/blob - src/visualizer.cpp
0d4998f6f19513479cd6b205d7e5da3adeaf26f4
[boost-statechart-viewer.git] / src / visualizer.cpp
1 /** @file */
2 ////////////////////////////////////////////////////////////////////////////////////////
3 //
4 //    This file is part of Boost Statechart Viewer.
5 //
6 //    Boost Statechart Viewer is free software: you can redistribute it and/or modify
7 //    it under the terms of the GNU General Public License as published by
8 //    the Free Software Foundation, either version 3 of the License, or
9 //    (at your option) any later version.
10 //
11 //    Boost Statechart Viewer is distributed in the hope that it will be useful,
12 //    but WITHOUT ANY WARRANTY; without even the implied warranty of
13 //    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 //    GNU General Public License for more details.
15 //
16 //    You should have received a copy of the GNU General Public License
17 //    along with Boost Statechart Viewer.  If not, see <http://www.gnu.org/licenses/>.
18 //
19 ////////////////////////////////////////////////////////////////////////////////////////
20
//standard header files
#include <fstream>
#include <list>
#include <string>
23
24 //LLVM Header files
25 #include "llvm/Support/raw_ostream.h"
26 #include "llvm/Support/raw_os_ostream.h"
27
28 //clang header files
29 #include "clang/AST/ASTConsumer.h"
30 #include "clang/AST/CXXInheritance.h"
31 #include "clang/AST/RecursiveASTVisitor.h"
32 #include "clang/Frontend/CompilerInstance.h"
33 #include "clang/Frontend/FrontendPluginRegistry.h"
34
35 using namespace clang;
36 using namespace std;
37
/**
 * Collected description of one state machine (states, events, transitions)
 * that can be dumped as a Graphviz DOT graph.
 */
class Statechart
{
public:
    /** One edge of the graph: from state @c src to state @c dst, labelled with @c event. */
    class Transition
    {
    public:
        std::string src, dst, event;
        Transition(std::string src, std::string dst, std::string event) : src(src), dst(dst), event(event) {}
    };
    std::string name;           /**< graph name (state machine name); may be empty */
    std::string name_of_start;  /**< initial state, drawn with a double border; may be empty */
    std::list<Transition> transitions;
    std::list<std::string> cReactions; /**< list of custom reactions. After all files are traversed this list should be empty. */
    std::list<std::string> events;
    std::list<std::string> states;

    /**
     * Write the collected graph to the file @p fn in DOT format.
     *
     * Every identifier is written as a quoted DOT ID, so qualified C++ names
     * such as "ns::Machine" (which contain ':') still yield a valid file.
     * The graph is anonymous when @c name is empty, and the start-state line
     * is omitted when @c name_of_start is empty (an empty bare ID would be a
     * DOT syntax error).
     *
     * If the file cannot be opened nothing is written; no error is reported.
     */
    void write_dot_file(const std::string &fn) const
    {
        std::ofstream f(fn.c_str());
        if (!f)
            return;

        f << "digraph ";
        if (!name.empty())
            f << quote(name) << ' ';
        f << "{\n";

        if (!name_of_start.empty())
            f << "  " << quote(name_of_start) << " [peripheries=2]\n";
        for (const std::string &s : states)
            f << "  " << quote(s) << "\n";

        for (const Transition &t : transitions)
            f << "  " << quote(t.src) << " -> " << quote(t.dst)
              << " [label = " << quote(t.event) << "]\n";

        f << "}\n";
    }

private:
    /**
     * Return @p id as a DOT double-quoted string. Per the DOT grammar the only
     * escape inside a quoted string is \" , so only '"' is escaped.
     */
    static std::string quote(const std::string &id)
    {
        std::string out;
        out.reserve(id.size() + 2);
        out += '"';
        for (char c : id) {
            if (c == '"')
                out += '\\';
            out += c;
        }
        out += '"';
        return out;
    }
};
70
71
72 class MyCXXRecordDecl : public CXXRecordDecl
73 {
74     static bool FindBaseClassString(const CXXBaseSpecifier *Specifier,
75                                     CXXBasePath &Path,
76                                     void *qualName)
77     {
78         string qn(static_cast<const char*>(qualName));
79         const RecordType *rt = Specifier->getType()->getAs<RecordType>();
80         assert(rt);
81         TagDecl *canon = rt->getDecl()->getCanonicalDecl();
82         return canon->getQualifiedNameAsString() == qn;
83     }
84
85 public:
86     bool isDerivedFrom(const char *baseStr, CXXBaseSpecifier const **Base = 0) const {
87         CXXBasePaths Paths(/*FindAmbiguities=*/false, /*RecordPaths=*/!!Base, /*DetectVirtual=*/false);
88         Paths.setOrigin(const_cast<MyCXXRecordDecl*>(this));
89         if (!lookupInBases(&FindBaseClassString, const_cast<char*>(baseStr), Paths))
90             return false;
91         if (Base)
92             *Base = Paths.front().back().Base;
93         return true;
94     }
95 };
96
97
98 class Visitor : public RecursiveASTVisitor<Visitor>
99 {
100     ASTContext *Context;
101     Statechart &sc;
102     DiagnosticsEngine &Diags;
103     unsigned diag_unhandled_reaction_type, diag_unhandled_reaction_decl,
104         diag_found_state, diag_found_statemachine;
105
106 public:
107     bool shouldVisitTemplateInstantiations() const { return true; }
108
109     explicit Visitor(ASTContext *Context, Statechart &sc, DiagnosticsEngine &Diags)
110         : Context(Context), sc(sc), Diags(Diags)
111     {
112         diag_found_statemachine =
113             Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found statemachine '%0'");
114         diag_found_state =
115             Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found state '%0'");
116         diag_unhandled_reaction_type =
117             Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction type '%0'");
118         diag_unhandled_reaction_decl =
119             Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction decl '%0'");
120     }
121
122     DiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID) { return Diags.Report(Loc, DiagID); }
123
124     void HandleReaction(const Type *T, const SourceLocation Loc, CXXRecordDecl *SrcState)
125     {
126         if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(T))
127             HandleReaction(ET->getNamedType().getTypePtr(), Loc, SrcState);
128         else if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(T)) {
129             string name = TST->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
130             if (name == "boost::statechart::transition") {
131                 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
132                 const Type *DstStateType = TST->getArg(1).getAsType().getTypePtr();
133                 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
134                 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
135
136                 sc.transitions.push_back(Statechart::Transition(SrcState->getName(), DstState->getName(),
137                                                                 Event->getName()));
138             } else if (name == "boost::mpl::list") {
139                 for (TemplateSpecializationType::iterator Arg = TST->begin(), End = TST->end(); Arg != End; ++Arg)
140                     HandleReaction(Arg->getAsType().getTypePtr(), Loc, SrcState);
141             }
142             //->getDecl()->getQualifiedNameAsString();
143         } else
144             Diag(Loc, diag_unhandled_reaction_type) << T->getTypeClassName();
145     }
146
147     void HandleReaction(const NamedDecl *Decl, CXXRecordDecl *SrcState)
148     {
149         if (const TypedefDecl *r = dyn_cast<TypedefDecl>(Decl))
150             HandleReaction(r->getCanonicalDecl()->getUnderlyingType().getTypePtr(),
151                            r->getLocStart(), SrcState);
152         else
153             Diag(Decl->getLocation(), diag_unhandled_reaction_decl) << Decl->getDeclKindName();
154     }
155
156
157     bool VisitCXXRecordDecl(CXXRecordDecl *Declaration)
158     {
159         if (!Declaration->isCompleteDefinition())
160             return true;
161
162         MyCXXRecordDecl *RecordDecl = static_cast<MyCXXRecordDecl*>(Declaration);
163         const CXXBaseSpecifier *Base;
164
165         if (RecordDecl->isDerivedFrom("boost::statechart::simple_state"))
166         {
167             string state(RecordDecl->getName()); //getQualifiedNameAsString());
168             Diag(RecordDecl->getLocStart(), diag_found_state) << state;
169             sc.states.push_back(state);
170
171             IdentifierInfo& II = Context->Idents.get("reactions");
172             // TODO: Lookup for reactions even in base classes - probably by using Sema::LookupQualifiedName()
173             for (DeclContext::lookup_result Reactions = RecordDecl->lookup(DeclarationName(&II));
174                  Reactions.first != Reactions.second; ++Reactions.first)
175                 HandleReaction(*Reactions.first, RecordDecl);
176         }
177         else if (RecordDecl->isDerivedFrom("boost::statechart::state_machine", &Base))
178         {
179             sc.name = RecordDecl->getQualifiedNameAsString();
180             Diag(RecordDecl->getLocStart(), diag_found_statemachine) << sc.name;
181
182             if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(Base->getType())) {
183                 if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(ET->getNamedType())) {
184                     sc.name_of_start = TST->getArg(1).getAsType()->getAsCXXRecordDecl()->getName();
185                 }
186             }
187         }
188         else if (RecordDecl->isDerivedFrom("boost::statechart::event"))
189         {
190             sc.events.push_back(RecordDecl->getNameAsString());
191         }
192         return true;
193     }
194 };
195
196
197 class VisualizeStatechartConsumer : public clang::ASTConsumer
198 {
199     Statechart statechart;
200     Visitor visitor;
201     string destFileName;
202 public:
203     explicit VisualizeStatechartConsumer(ASTContext *Context, std::string destFileName,
204                                          DiagnosticsEngine &D)
205         : visitor(Context, statechart, D), destFileName(destFileName) {}
206
207     virtual void HandleTranslationUnit(clang::ASTContext &Context) {
208         visitor.TraverseDecl(Context.getTranslationUnitDecl());
209         statechart.write_dot_file(destFileName);
210     }
211 };
212
213 class VisualizeStatechartAction : public PluginASTAction
214 {
215 protected:
216   ASTConsumer *CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) {
217     size_t dot = getCurrentFile().find_last_of('.');
218     std::string dest = getCurrentFile().substr(0, dot);
219     dest.append(".dot");
220     return new VisualizeStatechartConsumer(&CI.getASTContext(), dest, CI.getDiagnostics());
221   }
222
223   bool ParseArgs(const CompilerInstance &CI,
224                  const std::vector<std::string>& args) {
225     for (unsigned i = 0, e = args.size(); i != e; ++i) {
226       llvm::errs() << "Visualizer arg = " << args[i] << "\n";
227
228       // Example error handling.
229       if (args[i] == "-an-error") {
230         DiagnosticsEngine &D = CI.getDiagnostics();
231         unsigned DiagID = D.getCustomDiagID(
232           DiagnosticsEngine::Error, "invalid argument '" + args[i] + "'");
233         D.Report(DiagID);
234         return false;
235       }
236     }
237     if (args.size() && args[0] == "help")
238       PrintHelp(llvm::errs());
239
240     return true;
241   }
242   void PrintHelp(llvm::raw_ostream& ros) {
243     ros << "Help for Visualize Statechart plugin goes here\n";
244   }
245
246 };
247
248 static FrontendPluginRegistry::Add<VisualizeStatechartAction> X("visualize-statechart", "visualize statechart");
249
250 // Local Variables:
251 // c-basic-offset: 4
252 // End: