]> rtime.felk.cvut.cz Git - boost-statechart-viewer.git/blob - src/visualizer.cpp
Visualization of states
[boost-statechart-viewer.git] / src / visualizer.cpp
1 /** @file */
2 ////////////////////////////////////////////////////////////////////////////////////////
3 //
4 //    This file is part of Boost Statechart Viewer.
5 //
6 //    Boost Statechart Viewer is free software: you can redistribute it and/or modify
7 //    it under the terms of the GNU General Public License as published by
8 //    the Free Software Foundation, either version 3 of the License, or
9 //    (at your option) any later version.
10 //
11 //    Boost Statechart Viewer is distributed in the hope that it will be useful,
12 //    but WITHOUT ANY WARRANTY; without even the implied warranty of
13 //    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 //    GNU General Public License for more details.
15 //
16 //    You should have received a copy of the GNU General Public License
17 //    along with Boost Statechart Viewer.  If not, see <http://www.gnu.org/licenses/>.
18 //
19 ////////////////////////////////////////////////////////////////////////////////////////
20
21 //standard header files
22 #include <iomanip>
23 #include <fstream>
24 #include <map>
25
26 //LLVM Header files
27 #include "llvm/Support/raw_ostream.h"
28 #include "llvm/Support/raw_os_ostream.h"
29
30 //clang header files
31 #include "clang/AST/ASTConsumer.h"
32 #include "clang/AST/ASTContext.h"
33 #include "clang/AST/CXXInheritance.h"
34 #include "clang/AST/RecursiveASTVisitor.h"
35 #include "clang/Frontend/CompilerInstance.h"
36 #include "clang/Frontend/FrontendPluginRegistry.h"
37
38 using namespace clang;
39 using namespace std;
40
41 namespace Model
42 {
43
44     inline int getIndentLevelIdx() {
45         static int i = ios_base::xalloc();
46         return i;
47     }
48
49     ostream& indent(ostream& os) { os << setw(2*os.iword(getIndentLevelIdx())) << ""; return os; }
50     ostream& indent_inc(ostream& os) { os.iword(getIndentLevelIdx())++; return os; }
51     ostream& indent_dec(ostream& os) { os.iword(getIndentLevelIdx())--; return os; }
52
53     class State;
54
55     class Context : public map<string, State*> {
56     public:
57         iterator add(State *state);
58         Context *findContext(const string &name);
59     };
60
61     class State : public Context
62     {
63         string initialInnerState;
64         list<string> defferedEvents;
65     public:
66         const string name;
67         explicit State(string name) : name(name) {}
68         void setInitialInnerState(string name) { initialInnerState = name; }
69         void addDeferredEvent(const string &name) { defferedEvents.push_back(name); }
70         friend ostream& operator<<(ostream& os, const State& s);
71     };
72
73
74     Context::iterator Context::add(State *state)
75     {
76         pair<iterator, bool> ret =  insert(value_type(state->name, state));
77         return ret.first;
78     }
79
80     Context *Context::findContext(const string &name)
81     {
82         iterator i = find(name), e;
83         if (i != end())
84             return i->second;
85         for (i = begin(), e = end(); i != e; ++i) {
86             Context *c = i->second->findContext(name);
87             if (c)
88                 return c;
89         }
90         return 0;
91     }
92
93     ostream& operator<<(ostream& os, const Context& c);
94
95     ostream& operator<<(ostream& os, const State& s)
96     {
97         string label = s.name;
98         for (list<string>::const_iterator i = s.defferedEvents.begin(), e = s.defferedEvents.end(); i != e; ++i)
99             label.append("<br />").append(*i).append(" / defer");
100         os << indent << s.name << " [label=<" << label << ">]\n";
101         if (s.size()) {
102             os << indent << s.name << " -> " << s.initialInnerState << " [style = dashed]\n";
103             os << indent << "subgraph cluster_" << s.name << " {\n" << indent_inc;
104             os << indent << "label = \"" << s.name << "\"\n";
105             os << indent << s.initialInnerState << " [peripheries=2]\n";
106             os << static_cast<Context>(s);
107             os << indent_dec << indent << "}\n";
108         }
109         return os;
110     }
111
112
113     ostream& operator<<(ostream& os, const Context& c)
114     {
115         for (Context::const_iterator i = c.begin(), e = c.end(); i != e; i++) {
116             os << *i->second;
117         }
118         return os;
119     }
120
121
122     class Transition
123     {
124     public:
125         const string src, dst, event;
126         Transition(string src, string dst, string event) : src(src), dst(dst), event(event) {}
127     };
128
129     ostream& operator<<(ostream& os, const Transition& t)
130     {
131         os << indent << t.src << " -> " << t.dst << " [label = \"" << t.event << "\"]\n";
132         return os;
133     }
134
135
136     class Machine : public Context
137     {
138     protected:
139         string initial_state;
140     public:
141         const string name;
142         explicit Machine(string name) : name(name) {}
143
144         void setInitialState(string name) { initial_state = name; }
145
146         friend ostream& operator<<(ostream& os, const Machine& m);
147     };
148
149     ostream& operator<<(ostream& os, const Machine& m)
150     {
151         os << indent << "subgraph " << m.name << " {\n" << indent_inc;
152         os << indent << m.initial_state << " [peripheries=2]\n";
153         os << static_cast<Context>(m);
154         os << indent_dec << indent << "}\n";
155         return os;
156     }
157
158
159     class Model : public map<string, Machine>
160     {
161         Context undefined;      // For forward-declared state classes
162     public:
163         list< Transition*> transitions;
164
165         iterator add(const Machine &m)
166         {
167             pair<iterator, bool> ret =  insert(value_type(m.name, m));
168             return ret.first;
169         }
170
171         void addUndefinedState(State *m)
172         {
173             undefined[m->name] = m;
174         }
175
176
177         Context *findContext(const string &name)
178         {
179             Context::iterator ci = undefined.find(name);
180             if (ci != undefined.end())
181                 return ci->second;
182             iterator i = find(name), e;
183             if (i != end())
184                 return &i->second;
185             for (i = begin(), e = end(); i != e; ++i) {
186                 Context *c = i->second.findContext(name);
187                 if (c)
188                     return c;
189             }
190             return 0;
191         }
192
193         State *findState(const string &name)
194         {
195             for (iterator i = begin(), e = end(); i != e; ++i) {
196                 Context *c = i->second.findContext(name);
197                 if (c)
198                     return static_cast<State*>(c);
199             }
200             return 0;
201         }
202
203
204         State *removeFromUndefinedContexts(const string &name)
205         {
206             Context::iterator ci = undefined.find(name);
207             if (ci == undefined.end())
208                 return 0;
209             undefined.erase(ci);
210             return ci->second;
211         }
212
213         void write_as_dot_file(string fn)
214         {
215             ofstream f(fn.c_str());
216             f << "digraph statecharts {\n" << indent_inc;
217             for (iterator i = begin(), e = end(); i != e; i++)
218                 f << i->second;
219             for (list<Transition*>::iterator t = transitions.begin(), e = transitions.end(); t != e; ++t)
220                 f << **t;
221             f << indent_dec << "}\n";
222         }
223     };
224 };
225
226
227 class MyCXXRecordDecl : public CXXRecordDecl
228 {
229     static bool FindBaseClassString(const CXXBaseSpecifier *Specifier,
230                                     CXXBasePath &Path,
231                                     void *qualName)
232     {
233         string qn(static_cast<const char*>(qualName));
234         const RecordType *rt = Specifier->getType()->getAs<RecordType>();
235         assert(rt);
236         TagDecl *canon = rt->getDecl()->getCanonicalDecl();
237         return canon->getQualifiedNameAsString() == qn;
238     }
239
240 public:
241     bool isDerivedFrom(const char *baseStr, CXXBaseSpecifier const **Base = 0) const {
242         CXXBasePaths Paths(/*FindAmbiguities=*/false, /*RecordPaths=*/!!Base, /*DetectVirtual=*/false);
243         Paths.setOrigin(const_cast<MyCXXRecordDecl*>(this));
244         if (!lookupInBases(&FindBaseClassString, const_cast<char*>(baseStr), Paths))
245             return false;
246         if (Base)
247             *Base = Paths.front().back().Base;
248         return true;
249     }
250 };
251
252 class FindTransitVisitor : public RecursiveASTVisitor<FindTransitVisitor>
253 {
254     Model::Model &model;
255     const CXXRecordDecl *SrcState;
256     const Type *EventType;
257 public:
258     explicit FindTransitVisitor(Model::Model &model, const CXXRecordDecl *SrcState, const Type *EventType)
259         : model(model), SrcState(SrcState), EventType(EventType) {}
260
261     bool VisitMemberExpr(MemberExpr *E) {
262         if (E->getMemberNameInfo().getAsString() != "transit")
263             return true;
264         if (E->hasExplicitTemplateArgs()) {
265             const Type *DstStateType = E->getExplicitTemplateArgs()[0].getArgument().getAsType().getTypePtr();
266             CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
267             CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
268             Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
269             model.transitions.push_back(T);
270         }
271         return true;
272     }
273 };
274
275 class Visitor : public RecursiveASTVisitor<Visitor>
276 {
277     ASTContext *ASTCtx;
278     Model::Model &model;
279     DiagnosticsEngine &Diags;
280     unsigned diag_unhandled_reaction_type, diag_unhandled_reaction_decl,
281         diag_found_state, diag_found_statemachine, diag_no_history, diag_warning;
282
283 public:
284     bool shouldVisitTemplateInstantiations() const { return true; }
285
286     explicit Visitor(ASTContext *Context, Model::Model &model, DiagnosticsEngine &Diags)
287         : ASTCtx(Context), model(model), Diags(Diags)
288     {
289         diag_found_statemachine =
290             Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found statemachine '%0'");
291         diag_found_state =
292             Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found state '%0'");
293         diag_unhandled_reaction_type =
294             Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction type '%0'");
295         diag_unhandled_reaction_decl =
296             Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction decl '%0'");
297         diag_unhandled_reaction_decl =
298             Diags.getCustomDiagID(DiagnosticsEngine::Error, "History is not yet supported");
299         diag_warning =
300             Diags.getCustomDiagID(DiagnosticsEngine::Warning, "'%0' %1");
301     }
302
303     DiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID) { return Diags.Report(Loc, DiagID); }
304
305     void HandleCustomReaction(const CXXRecordDecl *SrcState, const Type *EventType)
306     {
307         IdentifierInfo& II = ASTCtx->Idents.get("react");
308         // TODO: Lookup for react even in base classes - probably by using Sema::LookupQualifiedName()
309         for (DeclContext::lookup_const_result ReactRes = SrcState->lookup(DeclarationName(&II));
310              ReactRes.first != ReactRes.second; ++ReactRes.first) {
311             if (CXXMethodDecl *React = dyn_cast<CXXMethodDecl>(*ReactRes.first)) {
312                 if (React->getNumParams() >= 1) {
313                     const ParmVarDecl *p = React->getParamDecl(0);
314                     const Type *ParmType = p->getType().getTypePtr();
315                     if (ParmType->isLValueReferenceType())
316                         ParmType = dyn_cast<LValueReferenceType>(ParmType)->getPointeeType().getTypePtr();
317                     if (ParmType == EventType)
318                         FindTransitVisitor(model, SrcState, EventType).TraverseStmt(React->getBody());
319                 } else
320                     Diag(React->getLocStart(), diag_warning)
321                         << React << "has not a parameter";
322             } else
323                 Diag((*ReactRes.first)->getSourceRange().getBegin(), diag_warning)
324                     << (*ReactRes.first)->getDeclKindName() << "is not supported as react method";
325         }
326     }
327
328     void HandleReaction(const Type *T, const SourceLocation Loc, CXXRecordDecl *SrcState)
329     {
330         // TODO: Improve Loc tracking
331         if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(T))
332             HandleReaction(ET->getNamedType().getTypePtr(), Loc, SrcState);
333         else if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(T)) {
334             string name = TST->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
335             if (name == "boost::statechart::transition") {
336                 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
337                 const Type *DstStateType = TST->getArg(1).getAsType().getTypePtr();
338                 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
339                 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
340
341                 Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
342                 model.transitions.push_back(T);
343             } else if (name == "boost::statechart::custom_reaction") {
344                 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
345                 HandleCustomReaction(SrcState, EventType);
346             } else if (name == "boost::statechart::deferral") {
347                 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
348                 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
349
350                 Model::State *s = model.findState(SrcState->getName());
351                 assert(s);
352                 s->addDeferredEvent(Event->getName());
353             } else if (name == "boost::mpl::list") {
354                 for (TemplateSpecializationType::iterator Arg = TST->begin(), End = TST->end(); Arg != End; ++Arg)
355                     HandleReaction(Arg->getAsType().getTypePtr(), Loc, SrcState);
356             } else
357                 Diag(Loc, diag_unhandled_reaction_type) << name;
358         } else
359             Diag(Loc, diag_unhandled_reaction_type) << T->getTypeClassName();
360     }
361
362     void HandleReaction(const NamedDecl *Decl, CXXRecordDecl *SrcState)
363     {
364         if (const TypedefDecl *r = dyn_cast<TypedefDecl>(Decl))
365             HandleReaction(r->getCanonicalDecl()->getUnderlyingType().getTypePtr(),
366                            r->getLocStart(), SrcState);
367         else
368             Diag(Decl->getLocation(), diag_unhandled_reaction_decl) << Decl->getDeclKindName();
369     }
370
371     TemplateArgumentLoc getTemplateArgLoc(const TypeLoc &T, unsigned ArgNum)
372     {
373         if (const ElaboratedTypeLoc *ET = dyn_cast<ElaboratedTypeLoc>(&T))
374             return getTemplateArgLoc(ET->getNamedTypeLoc(), ArgNum);
375         else if (const TemplateSpecializationTypeLoc *TST = dyn_cast<TemplateSpecializationTypeLoc>(&T)) {
376             if (TST->getNumArgs() >= ArgNum+1) {
377                 return TST->getArgLoc(ArgNum);
378             } else
379                 Diag(TST->getBeginLoc(), diag_warning) << TST->getType()->getTypeClassName() << "has not enough arguments" << TST->getSourceRange();
380         } else
381             Diag(T.getBeginLoc(), diag_warning) << T.getType()->getTypeClassName() << "type as template argument is not supported" << T.getSourceRange();
382         return TemplateArgumentLoc();
383     }
384
385     TemplateArgumentLoc getTemplateArgLocOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum) {
386         return getTemplateArgLoc(Base->getTypeSourceInfo()->getTypeLoc(), 1);
387     }
388
389     CXXRecordDecl *getTemplateArgDeclOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, TemplateArgumentLoc &Loc) {
390         Loc = getTemplateArgLocOfBase(Base, 1);
391         switch (Loc.getArgument().getKind()) {
392         case TemplateArgument::Type:
393             return Loc.getTypeSourceInfo()->getType()->getAsCXXRecordDecl();
394         case TemplateArgument::Null:
395             // Diag() was already called
396             break;
397         default:
398             Diag(Loc.getSourceRange().getBegin(), diag_warning) << Loc.getArgument().getKind() << "unsupported kind" << Loc.getSourceRange();
399         }
400         return 0;
401     }
402
403     CXXRecordDecl *getTemplateArgDeclOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum) {
404         TemplateArgumentLoc Loc;
405         return getTemplateArgDeclOfBase(Base, ArgNum, Loc);
406     }
407
408     void handleSimpleState(CXXRecordDecl *RecordDecl, const CXXBaseSpecifier *Base)
409     {
410         string name(RecordDecl->getName()); //getQualifiedNameAsString());
411         Diag(RecordDecl->getLocStart(), diag_found_state) << name;
412
413         Model::State *state;
414         // Either we saw a reference to forward declared state
415         // before, or we create a new state.
416         if (!(state = model.removeFromUndefinedContexts(name)))
417             state = new Model::State(name);
418
419         CXXRecordDecl *Context = getTemplateArgDeclOfBase(Base, 1);
420         if (Context) {
421             Model::Context *c = model.findContext(Context->getName());
422             if (!c) {
423                 Model::State *s = new Model::State(Context->getName());
424                 model.addUndefinedState(s);
425                 c = s;
426             }
427             c->add(state);
428         }
429
430         TemplateArgumentLoc Loc;
431         if (MyCXXRecordDecl *InnerInitialState =
432             static_cast<MyCXXRecordDecl*>(getTemplateArgDeclOfBase(Base, 2, Loc))) {
433             if (InnerInitialState->isDerivedFrom("boost::statechart::simple_state") ||
434                 InnerInitialState->isDerivedFrom("boost::statechart::state_machine"))
435                 state->setInitialInnerState(InnerInitialState->getName());
436             else
437                 Diag(Loc.getTypeSourceInfo()->getTypeLoc().getLocStart(), diag_warning)
438                     << InnerInitialState->getName() << " as inner initial state is not supported" << Loc.getSourceRange();
439         }
440
441 //          if (CXXRecordDecl *History = getTemplateArgDecl(Base->getType().getTypePtr(), 3))
442 //              Diag(History->getLocStart(), diag_no_history);
443
444         IdentifierInfo& II = ASTCtx->Idents.get("reactions");
445         // TODO: Lookup for reactions even in base classes - probably by using Sema::LookupQualifiedName()
446         // TODO: Find when state has no reactions
447         for (DeclContext::lookup_result Reactions = RecordDecl->lookup(DeclarationName(&II));
448              Reactions.first != Reactions.second; ++Reactions.first)
449             HandleReaction(*Reactions.first, RecordDecl);
450     }
451
452     void handleStateMachine(CXXRecordDecl *RecordDecl, const CXXBaseSpecifier *Base)
453     {
454         Model::Machine m(RecordDecl->getName());
455         Diag(RecordDecl->getLocStart(), diag_found_statemachine) << m.name;
456
457         if (MyCXXRecordDecl *InitialState =
458             static_cast<MyCXXRecordDecl*>(getTemplateArgDeclOfBase(Base, 1)))
459             m.setInitialState(InitialState->getName());
460         model.add(m);
461     }
462
463     bool VisitCXXRecordDecl(CXXRecordDecl *Declaration)
464     {
465         if (!Declaration->isCompleteDefinition())
466             return true;
467         if (Declaration->getQualifiedNameAsString() == "boost::statechart::state")
468             return true; // This is an "abstract class" not a real state
469         if (Declaration->getQualifiedNameAsString() == "TimedState")
470             return true; // This is an "abstract class" not a real state
471         if (Declaration->getQualifiedNameAsString() == "TimedSimpleState")
472             return true; // This is an "abstract class" not a real state
473
474
475         MyCXXRecordDecl *RecordDecl = static_cast<MyCXXRecordDecl*>(Declaration);
476         const CXXBaseSpecifier *Base;
477
478         if (RecordDecl->isDerivedFrom("boost::statechart::simple_state", &Base))
479             handleSimpleState(RecordDecl, Base);
480         else if (RecordDecl->isDerivedFrom("boost::statechart::state_machine", &Base))
481             handleStateMachine(RecordDecl, Base);
482         else if (RecordDecl->isDerivedFrom("boost::statechart::event"))
483         {
484             //sc.events.push_back(RecordDecl->getNameAsString());
485         }
486         return true;
487     }
488 };
489
490
491 class VisualizeStatechartConsumer : public clang::ASTConsumer
492 {
493     Model::Model model;
494     Visitor visitor;
495     string destFileName;
496 public:
497     explicit VisualizeStatechartConsumer(ASTContext *Context, std::string destFileName,
498                                          DiagnosticsEngine &D)
499         : visitor(Context, model, D), destFileName(destFileName) {}
500
501     virtual void HandleTranslationUnit(clang::ASTContext &Context) {
502         visitor.TraverseDecl(Context.getTranslationUnitDecl());
503         model.write_as_dot_file(destFileName);
504     }
505 };
506
507 class VisualizeStatechartAction : public PluginASTAction
508 {
509 protected:
510   ASTConsumer *CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) {
511     size_t dot = getCurrentFile().find_last_of('.');
512     std::string dest = getCurrentFile().substr(0, dot);
513     dest.append(".dot");
514     return new VisualizeStatechartConsumer(&CI.getASTContext(), dest, CI.getDiagnostics());
515   }
516
517   bool ParseArgs(const CompilerInstance &CI,
518                  const std::vector<std::string>& args) {
519     for (unsigned i = 0, e = args.size(); i != e; ++i) {
520       llvm::errs() << "Visualizer arg = " << args[i] << "\n";
521
522       // Example error handling.
523       if (args[i] == "-an-error") {
524         DiagnosticsEngine &D = CI.getDiagnostics();
525         unsigned DiagID = D.getCustomDiagID(
526           DiagnosticsEngine::Error, "invalid argument '" + args[i] + "'");
527         D.Report(DiagID);
528         return false;
529       }
530     }
531     if (args.size() && args[0] == "help")
532       PrintHelp(llvm::errs());
533
534     return true;
535   }
536   void PrintHelp(llvm::raw_ostream& ros) {
537     ros << "Help for Visualize Statechart plugin goes here\n";
538   }
539
540 };
541
542 static FrontendPluginRegistry::Add<VisualizeStatechartAction> X("visualize-statechart", "visualize statechart");
543
544 // Local Variables:
545 // c-basic-offset: 4
546 // End: