]> rtime.felk.cvut.cz Git - boost-statechart-viewer.git/blob - src/visualizer.cpp
e72099880b03843558ca4eb987344ea2a1d25a51
[boost-statechart-viewer.git] / src / visualizer.cpp
1 /** @file */
2 ////////////////////////////////////////////////////////////////////////////////////////
3 //
4 //    This file is part of Boost Statechart Viewer.
5 //
6 //    Boost Statechart Viewer is free software: you can redistribute it and/or modify
7 //    it under the terms of the GNU General Public License as published by
8 //    the Free Software Foundation, either version 3 of the License, or
9 //    (at your option) any later version.
10 //
11 //    Boost Statechart Viewer is distributed in the hope that it will be useful,
12 //    but WITHOUT ANY WARRANTY; without even the implied warranty of
13 //    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 //    GNU General Public License for more details.
15 //
16 //    You should have received a copy of the GNU General Public License
17 //    along with Boost Statechart Viewer.  If not, see <http://www.gnu.org/licenses/>.
18 //
19 ////////////////////////////////////////////////////////////////////////////////////////
20
//standard header files
#include <fstream>
#include <iomanip>
#include <list>
#include <map>
#include <string>
25
26 //LLVM Header files
27 #include "llvm/Support/raw_ostream.h"
28 #include "llvm/Support/raw_os_ostream.h"
29
30 //clang header files
31 #include "clang/AST/ASTConsumer.h"
32 #include "clang/AST/ASTContext.h"
33 #include "clang/AST/CXXInheritance.h"
34 #include "clang/AST/RecursiveASTVisitor.h"
35 #include "clang/Frontend/CompilerInstance.h"
36 #include "clang/Frontend/FrontendPluginRegistry.h"
37
38 using namespace clang;
39 using namespace std;
40
/// In-memory model of the discovered state machines, states and transitions,
/// plus the code that serializes it as a Graphviz DOT file.
namespace Model
{

    /// Index of the per-stream iword() slot that stores the current indentation level.
    inline int getIndentLevelIdx() {
        static int i = std::ios_base::xalloc();
        return i;
    }

    /// Stream manipulator: emits two spaces per indentation level stored in the stream.
    std::ostream& indent(std::ostream& os) { os << std::setw(2*os.iword(getIndentLevelIdx())) << ""; return os; }
    /// Stream manipulator: increases the stream's indentation level by one.
    std::ostream& indent_inc(std::ostream& os) { os.iword(getIndentLevelIdx())++; return os; }
    /// Stream manipulator: decreases the stream's indentation level by one.
    std::ostream& indent_dec(std::ostream& os) { os.iword(getIndentLevelIdx())--; return os; }

    class State;

    /// A set of states keyed by state name; base of both composite states and machines.
    /// NOTE(review): the State pointers are allocated with new elsewhere and never freed.
    class Context : public std::map<std::string, State*> {
    public:
        /// Inserts state under its name; an existing entry with that name is kept.
        /// @return iterator to the entry stored under state->name.
        iterator add(State *state);
        /// Depth-first search for the context (state) called name in this context and
        /// all nested states. @return the context, or 0 when not found.
        Context *findContext(const std::string &name);
    };

    /// A (possibly composite) state; its inner states are the Context entries.
    class State : public Context
    {
        std::string initialInnerState;
    public:
        const std::string name;
        explicit State(std::string name) : name(name) {}
        void setInitialInnerState(std::string name) { initialInnerState = name; }
        friend std::ostream& operator<<(std::ostream& os, const State& s);
    };


    Context::iterator Context::add(State *state)
    {
        return insert(value_type(state->name, state)).first;
    }

    Context *Context::findContext(const std::string &name)
    {
        iterator direct = find(name);
        if (direct != end())
            return direct->second;
        for (iterator i = begin(); i != end(); ++i) {
            if (Context *nested = i->second->findContext(name))
                return nested;
        }
        return 0;
    }


    std::ostream& operator<<(std::ostream& os, const Context& c);

    /// Writes the state as a DOT node. A composite state additionally gets a dashed
    /// edge to its initial inner state and a cluster subgraph with its inner states.
    std::ostream& operator<<(std::ostream& os, const State& s)
    {
        os << indent << s.name << "\n";
        if (!s.empty()) {
            os << indent << s.name << " -> " << s.initialInnerState << " [style = dashed]\n";
            os << indent << "subgraph cluster_" << s.name << " {\n" << indent_inc;
            os << indent << "label = \"" << s.name << "\"\n";
            os << indent << s.initialInnerState << " [peripheries=2]\n";
            // Bind by reference: casting to the value type would copy the whole map.
            os << static_cast<const Context&>(s);
            os << indent_dec << indent << "}\n";
        }
        return os;
    }


    /// Writes every state of the context, in name order.
    std::ostream& operator<<(std::ostream& os, const Context& c)
    {
        for (Context::const_iterator i = c.begin(), e = c.end(); i != e; ++i)
            os << *i->second;
        return os;
    }


    /// A labelled edge between two states, identified by name.
    class Transition
    {
    public:
        const std::string src, dst, event;
        Transition(std::string src, std::string dst, std::string event) : src(src), dst(dst), event(event) {}
    };

    /// Writes the transition as a DOT edge labelled with the event name.
    std::ostream& operator<<(std::ostream& os, const Transition& t)
    {
        os << indent << t.src << " -> " << t.dst << " [label = \"" << t.event << "\"]\n";
        return os;
    }


    /// A state machine: its top-level states plus the name of the initial state.
    class Machine : public Context
    {
    protected:
        std::string initial_state;
    public:
        const std::string name;
        explicit Machine(std::string name) : name(name) {}

        void setInitialState(std::string name) { initial_state = name; }

        friend std::ostream& operator<<(std::ostream& os, const Machine& m);
    };

    /// Writes the machine as a DOT subgraph with its initial state double-circled.
    std::ostream& operator<<(std::ostream& os, const Machine& m)
    {
        os << indent << "subgraph " << m.name << " {\n" << indent_inc;
        os << indent << m.initial_state << " [peripheries=2]\n";
        os << static_cast<const Context&>(m);
        os << indent_dec << indent << "}\n";
        return os;
    }


    /// All machines keyed by name, all transitions, and placeholder states for
    /// contexts that were referenced before their definition was seen.
    class Model : public std::map<std::string, Machine>
    {
        Context unknown;        // For forward-declared state classes
    public:
        std::list<Transition*> transitions;

        /// Adds a copy of m; an existing machine with the same name is kept.
        iterator add(const Machine &m)
        {
            return insert(value_type(m.name, m)).first;
        }

        /// Registers m as a placeholder for a not-yet-defined context.
        void addUnknownState(State *m)
        {
            unknown[m->name] = m;
        }


        /// Looks up the context called name among the placeholders (including
        /// states nested in them), the machines, and all states of the machines.
        /// @return the context, or 0 when not found.
        Context *findContext(const std::string &name)
        {
            // Search recursively: states may already have been attached to a placeholder.
            if (Context *c = unknown.findContext(name))
                return c;
            iterator i = find(name);
            if (i != end())
                return &i->second;
            for (i = begin(); i != end(); ++i) {
                if (Context *c = i->second.findContext(name))
                    return c;
            }
            return 0;
        }

        /// Removes the top-level placeholder called name.
        /// @return the removed placeholder, or 0 when there was none.
        State *removeFromUnknownContexts(const std::string &name)
        {
            Context::iterator ci = unknown.find(name);
            if (ci == unknown.end())
                return 0;
            // Read the value before erase() invalidates the iterator.
            State *removed = ci->second;
            unknown.erase(ci);
            return removed;
        }

        /// Writes all machines and transitions as one DOT digraph into file fn.
        /// NOTE(review): failure to open fn is not reported.
        void write_as_dot_file(std::string fn)
        {
            std::ofstream f(fn.c_str());
            f << "digraph statecharts {\n" << indent_inc;
            for (iterator i = begin(), e = end(); i != e; ++i)
                f << i->second;
            for (std::list<Transition*>::iterator t = transitions.begin(), e = transitions.end(); t != e; ++t)
                f << **t;
            f << indent_dec << "}\n";
        }
    };
} // namespace Model
210
211
212 class MyCXXRecordDecl : public CXXRecordDecl
213 {
214     static bool FindBaseClassString(const CXXBaseSpecifier *Specifier,
215                                     CXXBasePath &Path,
216                                     void *qualName)
217     {
218         string qn(static_cast<const char*>(qualName));
219         const RecordType *rt = Specifier->getType()->getAs<RecordType>();
220         assert(rt);
221         TagDecl *canon = rt->getDecl()->getCanonicalDecl();
222         return canon->getQualifiedNameAsString() == qn;
223     }
224
225 public:
226     bool isDerivedFrom(const char *baseStr, CXXBaseSpecifier const **Base = 0) const {
227         CXXBasePaths Paths(/*FindAmbiguities=*/false, /*RecordPaths=*/!!Base, /*DetectVirtual=*/false);
228         Paths.setOrigin(const_cast<MyCXXRecordDecl*>(this));
229         if (!lookupInBases(&FindBaseClassString, const_cast<char*>(baseStr), Paths))
230             return false;
231         if (Base)
232             *Base = Paths.front().back().Base;
233         return true;
234     }
235 };
236
237 class FindTransitVisitor : public RecursiveASTVisitor<FindTransitVisitor>
238 {
239     Model::Model &model;
240     const CXXRecordDecl *SrcState;
241     const Type *EventType;
242 public:
243     explicit FindTransitVisitor(Model::Model &model, const CXXRecordDecl *SrcState, const Type *EventType)
244         : model(model), SrcState(SrcState), EventType(EventType) {}
245
246     bool VisitMemberExpr(MemberExpr *E) {
247         if (E->getMemberNameInfo().getAsString() != "transit")
248             return true;
249         if (E->hasExplicitTemplateArgs()) {
250             const Type *DstStateType = E->getExplicitTemplateArgs()[0].getArgument().getAsType().getTypePtr();
251             CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
252             CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
253             Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
254             model.transitions.push_back(T);
255         }
256         return true;
257     }
258 };
259
260 class Visitor : public RecursiveASTVisitor<Visitor>
261 {
262     ASTContext *ASTCtx;
263     Model::Model &model;
264     DiagnosticsEngine &Diags;
265     unsigned diag_unhandled_reaction_type, diag_unhandled_reaction_decl,
266         diag_found_state, diag_found_statemachine, diag_no_history;
267
268 public:
269     bool shouldVisitTemplateInstantiations() const { return true; }
270
271     explicit Visitor(ASTContext *Context, Model::Model &model, DiagnosticsEngine &Diags)
272         : ASTCtx(Context), model(model), Diags(Diags)
273     {
274         diag_found_statemachine =
275             Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found statemachine '%0'");
276         diag_found_state =
277             Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found state '%0'");
278         diag_unhandled_reaction_type =
279             Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction type '%0'");
280         diag_unhandled_reaction_decl =
281             Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction decl '%0'");
282         diag_unhandled_reaction_decl =
283             Diags.getCustomDiagID(DiagnosticsEngine::Error, "History is not yet supported");
284     }
285
286     DiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID) { return Diags.Report(Loc, DiagID); }
287
288     void HandleCustomReaction(const CXXRecordDecl *SrcState, const Type *EventType)
289     {
290         IdentifierInfo& II = ASTCtx->Idents.get("react");
291         // TODO: Lookup for react even in base classes - probably by using Sema::LookupQualifiedName()
292         for (DeclContext::lookup_const_result ReactRes = SrcState->lookup(DeclarationName(&II));
293              ReactRes.first != ReactRes.second; ++ReactRes.first) {
294             if (CXXMethodDecl *React = dyn_cast<CXXMethodDecl>(*ReactRes.first))
295                 if (const ParmVarDecl *p = React->getParamDecl(0)) {
296                     const Type *ParmType = p->getType().getTypePtr();
297                     if (ParmType->isLValueReferenceType())
298                         ParmType = dyn_cast<LValueReferenceType>(ParmType)->getPointeeType().getTypePtr();
299                     if (ParmType == EventType)
300                         FindTransitVisitor(model, SrcState, EventType).TraverseStmt(React->getBody());
301                 }
302         }
303     }
304
305     void HandleReaction(const Type *T, const SourceLocation Loc, CXXRecordDecl *SrcState)
306     {
307         // TODO: Improve Loc tracking
308         if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(T))
309             HandleReaction(ET->getNamedType().getTypePtr(), Loc, SrcState);
310         else if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(T)) {
311             string name = TST->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
312             if (name == "boost::statechart::transition") {
313                 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
314                 const Type *DstStateType = TST->getArg(1).getAsType().getTypePtr();
315                 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
316                 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
317
318                 Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
319                 model.transitions.push_back(T);
320             } else if (name == "boost::statechart::custom_reaction") {
321                 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
322                 HandleCustomReaction(SrcState, EventType);
323             } else if (name == "boost::statechart::deferral") {
324                 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
325                 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
326
327                 Model::Transition *T = new Model::Transition(SrcState->getName(), "\"??? deferral\"", Event->getName());
328                 model.transitions.push_back(T);
329             } else if (name == "boost::mpl::list") {
330                 for (TemplateSpecializationType::iterator Arg = TST->begin(), End = TST->end(); Arg != End; ++Arg)
331                     HandleReaction(Arg->getAsType().getTypePtr(), Loc, SrcState);
332             } else
333                 Diag(Loc, diag_unhandled_reaction_type) << name;
334         } else
335             Diag(Loc, diag_unhandled_reaction_type) << T->getTypeClassName();
336     }
337
338     void HandleReaction(const NamedDecl *Decl, CXXRecordDecl *SrcState)
339     {
340         if (const TypedefDecl *r = dyn_cast<TypedefDecl>(Decl))
341             HandleReaction(r->getCanonicalDecl()->getUnderlyingType().getTypePtr(),
342                            r->getLocStart(), SrcState);
343         else
344             Diag(Decl->getLocation(), diag_unhandled_reaction_decl) << Decl->getDeclKindName();
345     }
346
347     CXXRecordDecl *getTemplateArgDecl(const Type *T, unsigned ArgNum)
348     {
349         if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(T))
350             return getTemplateArgDecl(ET->getNamedType().getTypePtr(), ArgNum);
351         else if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(T)) {
352             if (TST->getNumArgs() >= ArgNum+1)
353                 return TST->getArg(ArgNum).getAsType()->getAsCXXRecordDecl();
354         }
355         return 0;
356     }
357
358
359     bool VisitCXXRecordDecl(CXXRecordDecl *Declaration)
360     {
361         if (!Declaration->isCompleteDefinition())
362             return true;
363
364         MyCXXRecordDecl *RecordDecl = static_cast<MyCXXRecordDecl*>(Declaration);
365         const CXXBaseSpecifier *Base;
366
367         if (RecordDecl->isDerivedFrom("boost::statechart::simple_state", &Base))
368         {
369             string name(RecordDecl->getName()); //getQualifiedNameAsString());
370             Diag(RecordDecl->getLocStart(), diag_found_state) << name;
371
372             Model::State *state;
373             // Either we saw a reference to forward declared state
374             // before, or we create a new state.
375             if (!(state = model.removeFromUnknownContexts(name)))
376                 // TODO: Fix the value of name
377                 state = new Model::State(name);
378
379             CXXRecordDecl *Context = getTemplateArgDecl(Base->getType().getTypePtr(), 1);
380             Model::Context *c = model.findContext(Context->getName());
381             if (!c) {
382                 Model::State *s = new Model::State(Context->getName());
383                 model.addUnknownState(s);
384                 c = s;
385             }
386             c->add(state);
387
388             if (CXXRecordDecl *InnerInitialState = getTemplateArgDecl(Base->getType().getTypePtr(), 2))
389                 state->setInitialInnerState(InnerInitialState->getName());
390
391 //          if (CXXRecordDecl *History = getTemplateArgDecl(Base->getType().getTypePtr(), 3))
392 //              Diag(History->getLocStart(), diag_no_history);
393
394             IdentifierInfo& II = ASTCtx->Idents.get("reactions");
395             // TODO: Lookup for reactions even in base classes - probably by using Sema::LookupQualifiedName()
396             for (DeclContext::lookup_result Reactions = RecordDecl->lookup(DeclarationName(&II));
397                  Reactions.first != Reactions.second; ++Reactions.first)
398                 HandleReaction(*Reactions.first, RecordDecl);
399         }
400         else if (RecordDecl->isDerivedFrom("boost::statechart::state_machine", &Base))
401         {
402             Model::Machine m(RecordDecl->getName());
403             Diag(RecordDecl->getLocStart(), diag_found_statemachine) << m.name;
404
405             if (CXXRecordDecl *InitialState = getTemplateArgDecl(Base->getType().getTypePtr(), 1))
406                 m.setInitialState(InitialState->getName());
407             model.add(m);
408         }
409         else if (RecordDecl->isDerivedFrom("boost::statechart::event"))
410         {
411             //sc.events.push_back(RecordDecl->getNameAsString());
412         }
413         return true;
414     }
415 };
416
417
418 class VisualizeStatechartConsumer : public clang::ASTConsumer
419 {
420     Model::Model model;
421     Visitor visitor;
422     string destFileName;
423 public:
424     explicit VisualizeStatechartConsumer(ASTContext *Context, std::string destFileName,
425                                          DiagnosticsEngine &D)
426         : visitor(Context, model, D), destFileName(destFileName) {}
427
428     virtual void HandleTranslationUnit(clang::ASTContext &Context) {
429         visitor.TraverseDecl(Context.getTranslationUnitDecl());
430         model.write_as_dot_file(destFileName);
431     }
432 };
433
434 class VisualizeStatechartAction : public PluginASTAction
435 {
436 protected:
437   ASTConsumer *CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) {
438     size_t dot = getCurrentFile().find_last_of('.');
439     std::string dest = getCurrentFile().substr(0, dot);
440     dest.append(".dot");
441     return new VisualizeStatechartConsumer(&CI.getASTContext(), dest, CI.getDiagnostics());
442   }
443
444   bool ParseArgs(const CompilerInstance &CI,
445                  const std::vector<std::string>& args) {
446     for (unsigned i = 0, e = args.size(); i != e; ++i) {
447       llvm::errs() << "Visualizer arg = " << args[i] << "\n";
448
449       // Example error handling.
450       if (args[i] == "-an-error") {
451         DiagnosticsEngine &D = CI.getDiagnostics();
452         unsigned DiagID = D.getCustomDiagID(
453           DiagnosticsEngine::Error, "invalid argument '" + args[i] + "'");
454         D.Report(DiagID);
455         return false;
456       }
457     }
458     if (args.size() && args[0] == "help")
459       PrintHelp(llvm::errs());
460
461     return true;
462   }
463   void PrintHelp(llvm::raw_ostream& ros) {
464     ros << "Help for Visualize Statechart plugin goes here\n";
465   }
466
467 };
468
469 static FrontendPluginRegistry::Add<VisualizeStatechartAction> X("visualize-statechart", "visualize statechart");
470
471 // Local Variables:
472 // c-basic-offset: 4
473 // End: