rtime.felk.cvut.cz Git - boost-statechart-viewer.git/blob - src/visualizer.cpp
Error when missing react method for custom_reaction mentioned in typedef.
[boost-statechart-viewer.git] / src / visualizer.cpp
1 /** @file */
2 ////////////////////////////////////////////////////////////////////////////////////////
3 //
4 //    This file is part of Boost Statechart Viewer.
5 //
6 //    Boost Statechart Viewer is free software: you can redistribute it and/or modify
7 //    it under the terms of the GNU General Public License as published by
8 //    the Free Software Foundation, either version 3 of the License, or
9 //    (at your option) any later version.
10 //
11 //    Boost Statechart Viewer is distributed in the hope that it will be useful,
12 //    but WITHOUT ANY WARRANTY; without even the implied warranty of
13 //    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 //    GNU General Public License for more details.
15 //
16 //    You should have received a copy of the GNU General Public License
17 //    along with Boost Statechart Viewer.  If not, see <http://www.gnu.org/licenses/>.
18 //
19 ////////////////////////////////////////////////////////////////////////////////////////
20
//standard header files
#include <cassert>
#include <fstream>
#include <iomanip>
#include <list>
#include <map>
#include <string>

//LLVM Header files
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/raw_os_ostream.h"

//clang header files
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/CXXInheritance.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendPluginRegistry.h"

using namespace clang;
using namespace std;
40
/// In-memory model of the discovered state machines, states and transitions,
/// plus the code that writes it out in Graphviz dot syntax.
namespace Model
{

    /// Returns the index of the per-stream storage slot (std::ios_base::iword)
    /// holding the current indentation level used by the manipulators below.
    inline int getIndentLevelIdx() {
        static int i = std::ios_base::xalloc();
        return i;
    }

    /// Manipulator: writes two spaces per indentation level of @p os.
    std::ostream& indent(std::ostream& os) { os << std::setw(static_cast<int>(2*os.iword(getIndentLevelIdx()))) << ""; return os; }
    /// Manipulator: increments the indentation level of @p os.
    std::ostream& indent_inc(std::ostream& os) { os.iword(getIndentLevelIdx())++; return os; }
    /// Manipulator: decrements the indentation level of @p os.
    std::ostream& indent_dec(std::ostream& os) { os.iword(getIndentLevelIdx())--; return os; }

    class State;

    /// A collection of child states keyed by state name.
    /// The State pointers are stored as-is and are never deleted here.
    class Context : public std::map<std::string, State*> {
    public:
        /// Inserts @p state under its name. If a state with that name is
        /// already present, nothing is inserted. Returns an iterator to the
        /// element stored under that name.
        iterator add(State *state);
        /// Depth-first search for a state called @p name in this context and
        /// in all nested states. Returns 0 when not found.
        Context *findContext(const std::string &name);
    };

    /// A single state. It may itself contain inner states (it is a Context).
    class State : public Context
    {
        std::string initialInnerState;
        std::list<std::string> deferredEvents;
        bool noTypedef;     // true when the state has no `reactions` typedef (drawn red)
    public:
        const std::string name;
        explicit State(std::string name) : noTypedef(false), name(name) {}
        void setInitialInnerState(std::string name) { initialInnerState = name; }
        void addDeferredEvent(const std::string &name) { deferredEvents.push_back(name); }
        void setNoTypedef() { noTypedef = true; }
        friend std::ostream& operator<<(std::ostream& os, const State& s);
    };


    Context::iterator Context::add(State *state)
    {
        return insert(value_type(state->name, state)).first;
    }

    Context *Context::findContext(const std::string &name)
    {
        iterator i = find(name), e;
        if (i != end())
            return i->second;
        for (i = begin(), e = end(); i != e; ++i) {
            Context *c = i->second->findContext(name);
            if (c)
                return c;
        }
        return 0;
    }

    std::ostream& operator<<(std::ostream& os, const Context& c);

    /// Writes the state as a dot node. Deferred events are appended to the
    /// HTML label. A state with inner states additionally gets:
    /// - a dashed edge to its initial inner state;
    /// - a cluster subgraph containing the inner states.
    std::ostream& operator<<(std::ostream& os, const State& s)
    {
        std::string label = s.name;
        for (std::list<std::string>::const_iterator i = s.deferredEvents.begin(), e = s.deferredEvents.end(); i != e; ++i)
            label.append("<br />").append(*i).append(" / defer");
        os << indent << s.name << " [label=<" << label << ">";
        if (s.noTypedef)
            os << ", color=\"red\"";
        os << "]\n";
        if (!s.empty()) {
            // Without a known initial inner state, the edge/marker would be invalid dot.
            const bool hasInitial = !s.initialInnerState.empty();
            if (hasInitial)
                os << indent << s.name << " -> " << s.initialInnerState << " [style = dashed]\n";
            os << indent << "subgraph cluster_" << s.name << " {\n" << indent_inc;
            os << indent << "label = \"" << s.name << "\"\n";
            if (hasInitial)
                os << indent << s.initialInnerState << " [peripheries=2]\n";
            // Cast to a reference: a value cast would copy the whole child map.
            os << static_cast<const Context&>(s);
            os << indent_dec << indent << "}\n";
        }
        return os;
    }


    /// Writes every state of the context, in key (name) order.
    std::ostream& operator<<(std::ostream& os, const Context& c)
    {
        for (Context::const_iterator i = c.begin(), e = c.end(); i != e; ++i) {
            os << *i->second;
        }
        return os;
    }


    /// A labelled edge src -> dst triggered by event.
    class Transition
    {
    public:
        const std::string src, dst, event;
        Transition(std::string src, std::string dst, std::string event) : src(src), dst(dst), event(event) {}
    };

    std::ostream& operator<<(std::ostream& os, const Transition& t)
    {
        os << indent << t.src << " -> " << t.dst << " [label = \"" << t.event << "\"]\n";
        return os;
    }


    /// A state machine: its top-level states plus the name of the initial state.
    class Machine : public Context
    {
    protected:
        std::string initial_state;
    public:
        const std::string name;
        explicit Machine(std::string name) : name(name) {}

        void setInitialState(std::string name) { initial_state = name; }

        friend std::ostream& operator<<(std::ostream& os, const Machine& m);
    };

    /// Writes the machine as a dot subgraph, marking the initial state (if known).
    std::ostream& operator<<(std::ostream& os, const Machine& m)
    {
        os << indent << "subgraph " << m.name << " {\n" << indent_inc;
        if (!m.initial_state.empty())
            os << indent << m.initial_state << " [peripheries=2]\n";
        os << static_cast<const Context&>(m);
        os << indent_dec << indent << "}\n";
        return os;
    }


    /// All machines keyed by name, plus the transitions between states.
    /// NOTE(review): States and Transitions are held through raw pointers and
    /// are never deleted by this class.
    class Model : public std::map<std::string, Machine>
    {
        Context undefined;      // For forward-declared state classes
    public:
        std::list< Transition*> transitions;

        /// Inserts a copy of @p m unless a machine of that name already exists.
        iterator add(const Machine &m)
        {
            return insert(value_type(m.name, m)).first;
        }

        /// Registers a state whose definition has not been seen yet.
        void addUndefinedState(State *m)
        {
            undefined[m->name] = m;
        }

        /// Looks up @p name in this order: the undefined states, then the
        /// machine names, then the states nested in each machine.
        /// Returns 0 when nothing matches.
        Context *findContext(const std::string &name)
        {
            Context::iterator ci = undefined.find(name);
            if (ci != undefined.end())
                return ci->second;
            iterator i = find(name), e;
            if (i != end())
                return &i->second;
            for (i = begin(), e = end(); i != e; ++i) {
                Context *c = i->second.findContext(name);
                if (c)
                    return c;
            }
            return 0;
        }

        /// Finds the state called @p name. The search covers every machine
        /// and any state nested below a not-yet-defined context.
        /// Returns 0 when not found.
        State *findState(const std::string &name)
        {
            for (iterator i = begin(), e = end(); i != e; ++i) {
                Context *c = i->second.findContext(name);
                if (c)
                    return static_cast<State*>(c);
            }
            // The state may still hang below a forward-declared (undefined) context.
            return static_cast<State*>(undefined.findContext(name));
        }


        /// Removes the undefined state @p name and returns it (0 if absent).
        State *removeFromUndefinedContexts(const std::string &name)
        {
            Context::iterator ci = undefined.find(name);
            if (ci == undefined.end())
                return 0;
            // Read the value before erase() invalidates the iterator.
            State *s = ci->second;
            undefined.erase(ci);
            return s;
        }

        /// Writes all machines and transitions as one dot digraph to file @p fn.
        void write_as_dot_file(std::string fn)
        {
            std::ofstream f(fn.c_str());
            f << "digraph statecharts {\n" << indent_inc;
            for (iterator i = begin(), e = end(); i != e; ++i)
                f << i->second;
            for (std::list<Transition*>::iterator t = transitions.begin(), e = transitions.end(); t != e; ++t)
                f << **t;
            f << indent_dec << "}\n";
        }
    };
} // namespace Model
228
229
230 class MyCXXRecordDecl : public CXXRecordDecl
231 {
232     static bool FindBaseClassString(const CXXBaseSpecifier *Specifier,
233                                     CXXBasePath &Path,
234                                     void *qualName)
235     {
236         string qn(static_cast<const char*>(qualName));
237         const RecordType *rt = Specifier->getType()->getAs<RecordType>();
238         assert(rt);
239         TagDecl *canon = rt->getDecl()->getCanonicalDecl();
240         return canon->getQualifiedNameAsString() == qn;
241     }
242
243 public:
244     bool isDerivedFrom(const char *baseStr, CXXBaseSpecifier const **Base = 0) const {
245         CXXBasePaths Paths(/*FindAmbiguities=*/false, /*RecordPaths=*/!!Base, /*DetectVirtual=*/false);
246         Paths.setOrigin(const_cast<MyCXXRecordDecl*>(this));
247         if (!lookupInBases(&FindBaseClassString, const_cast<char*>(baseStr), Paths))
248             return false;
249         if (Base)
250             *Base = Paths.front().back().Base;
251         return true;
252     }
253 };
254
255 class FindTransitVisitor : public RecursiveASTVisitor<FindTransitVisitor>
256 {
257     Model::Model &model;
258     const CXXRecordDecl *SrcState;
259     const Type *EventType;
260 public:
261     explicit FindTransitVisitor(Model::Model &model, const CXXRecordDecl *SrcState, const Type *EventType)
262         : model(model), SrcState(SrcState), EventType(EventType) {}
263
264     bool VisitMemberExpr(MemberExpr *E) {
265         if (E->getMemberNameInfo().getAsString() != "transit")
266             return true;
267         if (E->hasExplicitTemplateArgs()) {
268             const Type *DstStateType = E->getExplicitTemplateArgs()[0].getArgument().getAsType().getTypePtr();
269             CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
270             CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
271             Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
272             model.transitions.push_back(T);
273         }
274         return true;
275     }
276 };
277
278 class Visitor : public RecursiveASTVisitor<Visitor>
279 {
280     ASTContext *ASTCtx;
281     Model::Model &model;
282     DiagnosticsEngine &Diags;
283     unsigned diag_unhandled_reaction_type, diag_unhandled_reaction_decl,
284         diag_found_state, diag_found_statemachine, diag_no_history, diag_missing_reaction, diag_warning;
285
286 public:
287     bool shouldVisitTemplateInstantiations() const { return true; }
288
289     explicit Visitor(ASTContext *Context, Model::Model &model, DiagnosticsEngine &Diags)
290         : ASTCtx(Context), model(model), Diags(Diags)
291     {
292         diag_found_statemachine =
293             Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found statemachine '%0'");
294         diag_found_state =
295             Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found state '%0'");
296         diag_unhandled_reaction_type =
297             Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction type '%0'");
298         diag_unhandled_reaction_decl =
299             Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction decl '%0'");
300         diag_unhandled_reaction_decl =
301             Diags.getCustomDiagID(DiagnosticsEngine::Error, "History is not yet supported");
302         diag_missing_reaction =
303             Diags.getCustomDiagID(DiagnosticsEngine::Error, "Missing react method for event '%0'");
304         diag_warning =
305             Diags.getCustomDiagID(DiagnosticsEngine::Warning, "'%0' %1");
306     }
307
308     DiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID) { return Diags.Report(Loc, DiagID); }
309
310     bool HandleCustomReaction(const CXXRecordDecl *SrcState, const Type *EventType)
311     {
312         IdentifierInfo& II = ASTCtx->Idents.get("react");
313         // TODO: Lookup for react even in base classes - probably by using Sema::LookupQualifiedName()
314         for (DeclContext::lookup_const_result ReactRes = SrcState->lookup(DeclarationName(&II));
315              ReactRes.first != ReactRes.second; ++ReactRes.first) {
316             if (CXXMethodDecl *React = dyn_cast<CXXMethodDecl>(*ReactRes.first)) {
317                 if (React->getNumParams() >= 1) {
318                     const ParmVarDecl *p = React->getParamDecl(0);
319                     const Type *ParmType = p->getType().getTypePtr();
320                     if (ParmType->isLValueReferenceType())
321                         ParmType = dyn_cast<LValueReferenceType>(ParmType)->getPointeeType().getTypePtr();
322                     if (ParmType == EventType) {
323                         FindTransitVisitor(model, SrcState, EventType).TraverseStmt(React->getBody());
324                         return true;
325                     }
326                 } else
327                     Diag(React->getLocStart(), diag_warning)
328                         << React << "has not a parameter";
329             } else
330                 Diag((*ReactRes.first)->getSourceRange().getBegin(), diag_warning)
331                     << (*ReactRes.first)->getDeclKindName() << "is not supported as react method";
332         }
333         return false;
334     }
335
336     void HandleReaction(const Type *T, const SourceLocation Loc, CXXRecordDecl *SrcState)
337     {
338         // TODO: Improve Loc tracking
339         if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(T))
340             HandleReaction(ET->getNamedType().getTypePtr(), Loc, SrcState);
341         else if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(T)) {
342             string name = TST->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
343             if (name == "boost::statechart::transition") {
344                 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
345                 const Type *DstStateType = TST->getArg(1).getAsType().getTypePtr();
346                 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
347                 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
348
349                 Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
350                 model.transitions.push_back(T);
351             } else if (name == "boost::statechart::custom_reaction") {
352                 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
353                 if(!HandleCustomReaction(SrcState, EventType)) {
354                         Diag(SrcState->getLocation(), diag_missing_reaction) << EventType->getAsCXXRecordDecl()->getName();
355                 }
356             } else if (name == "boost::statechart::deferral") {
357                 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
358                 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
359
360                 Model::State *s = model.findState(SrcState->getName());
361                 assert(s);
362                 s->addDeferredEvent(Event->getName());
363             } else if (name == "boost::mpl::list") {
364                 for (TemplateSpecializationType::iterator Arg = TST->begin(), End = TST->end(); Arg != End; ++Arg)
365                     HandleReaction(Arg->getAsType().getTypePtr(), Loc, SrcState);
366             } else
367                 Diag(Loc, diag_unhandled_reaction_type) << name;
368         } else
369             Diag(Loc, diag_unhandled_reaction_type) << T->getTypeClassName();
370     }
371
372     void HandleReaction(const NamedDecl *Decl, CXXRecordDecl *SrcState)
373     {
374         if (const TypedefDecl *r = dyn_cast<TypedefDecl>(Decl))
375             HandleReaction(r->getCanonicalDecl()->getUnderlyingType().getTypePtr(),
376                            r->getLocStart(), SrcState);
377         else
378             Diag(Decl->getLocation(), diag_unhandled_reaction_decl) << Decl->getDeclKindName();
379     }
380
381     TemplateArgumentLoc getTemplateArgLoc(const TypeLoc &T, unsigned ArgNum, bool ignore)
382     {
383         if (const ElaboratedTypeLoc *ET = dyn_cast<ElaboratedTypeLoc>(&T))
384             return getTemplateArgLoc(ET->getNamedTypeLoc(), ArgNum, ignore);
385         else if (const TemplateSpecializationTypeLoc *TST = dyn_cast<TemplateSpecializationTypeLoc>(&T)) {
386             if (TST->getNumArgs() >= ArgNum+1) {
387                 return TST->getArgLoc(ArgNum);
388             } else
389                 if (!ignore)
390                     Diag(TST->getBeginLoc(), diag_warning) << TST->getType()->getTypeClassName() << "has not enough arguments" << TST->getSourceRange();
391         } else
392             Diag(T.getBeginLoc(), diag_warning) << T.getType()->getTypeClassName() << "type as template argument is not supported" << T.getSourceRange();
393         return TemplateArgumentLoc();
394     }
395
396     TemplateArgumentLoc getTemplateArgLocOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, bool ignore) {
397         return getTemplateArgLoc(Base->getTypeSourceInfo()->getTypeLoc(), ArgNum, ignore);
398     }
399
400     CXXRecordDecl *getTemplateArgDeclOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, TemplateArgumentLoc &Loc, bool ignore = false) {
401         Loc = getTemplateArgLocOfBase(Base, ArgNum, ignore);
402         switch (Loc.getArgument().getKind()) {
403         case TemplateArgument::Type:
404             return Loc.getTypeSourceInfo()->getType()->getAsCXXRecordDecl();
405         case TemplateArgument::Null:
406             // Diag() was already called
407             break;
408         default:
409             Diag(Loc.getSourceRange().getBegin(), diag_warning) << Loc.getArgument().getKind() << "unsupported kind" << Loc.getSourceRange();
410         }
411         return 0;
412     }
413
414     CXXRecordDecl *getTemplateArgDeclOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, bool ignore = false) {
415         TemplateArgumentLoc Loc;
416         return getTemplateArgDeclOfBase(Base, ArgNum, Loc, ignore);
417     }
418
419     void handleSimpleState(CXXRecordDecl *RecordDecl, const CXXBaseSpecifier *Base)
420     {
421         int typedef_num = 0;
422         string name(RecordDecl->getName()); //getQualifiedNameAsString());
423         Diag(RecordDecl->getLocStart(), diag_found_state) << name;
424
425         Model::State *state;
426         // Either we saw a reference to forward declared state
427         // before, or we create a new state.
428         if (!(state = model.removeFromUndefinedContexts(name)))
429             state = new Model::State(name);
430
431         CXXRecordDecl *Context = getTemplateArgDeclOfBase(Base, 1);
432         if (Context) {
433             Model::Context *c = model.findContext(Context->getName());
434             if (!c) {
435                 Model::State *s = new Model::State(Context->getName());
436                 model.addUndefinedState(s);
437                 c = s;
438             }
439             c->add(state);
440         }
441         //TODO support more innitial states
442         TemplateArgumentLoc Loc;
443         if (MyCXXRecordDecl *InnerInitialState =
444             static_cast<MyCXXRecordDecl*>(getTemplateArgDeclOfBase(Base, 2, Loc, true))) {
445               if (InnerInitialState->isDerivedFrom("boost::statechart::simple_state") ||
446                 InnerInitialState->isDerivedFrom("boost::statechart::state_machine")) {
447                   state->setInitialInnerState(InnerInitialState->getName());
448             }
449             else
450                 Diag(Loc.getTypeSourceInfo()->getTypeLoc().getLocStart(), diag_warning)
451                     << InnerInitialState->getName() << " as inner initial state is not supported" << Loc.getSourceRange();
452         }
453
454 //          if (CXXRecordDecl *History = getTemplateArgDecl(Base->getType().getTypePtr(), 3))
455 //              Diag(History->getLocStart(), diag_no_history);
456
457         IdentifierInfo& II = ASTCtx->Idents.get("reactions");
458         // TODO: Lookup for reactions even in base classes - probably by using Sema::LookupQualifiedName()
459         for (DeclContext::lookup_result Reactions = RecordDecl->lookup(DeclarationName(&II));
460              Reactions.first != Reactions.second; ++Reactions.first, typedef_num++)
461             HandleReaction(*Reactions.first, RecordDecl);
462         if(typedef_num == 0) {
463             Diag(RecordDecl->getLocStart(), diag_warning)
464                 << RecordDecl->getName() << "state has no typedef for reactions";
465             state->setNoTypedef();
466         }
467     }
468
469     void handleStateMachine(CXXRecordDecl *RecordDecl, const CXXBaseSpecifier *Base)
470     {
471         Model::Machine m(RecordDecl->getName());
472         Diag(RecordDecl->getLocStart(), diag_found_statemachine) << m.name;
473
474         if (MyCXXRecordDecl *InitialState =
475             static_cast<MyCXXRecordDecl*>(getTemplateArgDeclOfBase(Base, 1)))
476             m.setInitialState(InitialState->getName());
477         model.add(m);
478     }
479
480     bool VisitCXXRecordDecl(CXXRecordDecl *Declaration)
481     {
482         if (!Declaration->isCompleteDefinition())
483             return true;
484         if (Declaration->getQualifiedNameAsString() == "boost::statechart::state" ||
485             Declaration->getQualifiedNameAsString() == "TimedState" ||
486             Declaration->getQualifiedNameAsString() == "TimedSimpleState")
487             return true; // This is an "abstract class" not a real state
488
489         MyCXXRecordDecl *RecordDecl = static_cast<MyCXXRecordDecl*>(Declaration);
490         const CXXBaseSpecifier *Base;
491
492         if (RecordDecl->isDerivedFrom("boost::statechart::simple_state", &Base))
493             handleSimpleState(RecordDecl, Base);
494         else if (RecordDecl->isDerivedFrom("boost::statechart::state_machine", &Base))
495             handleStateMachine(RecordDecl, Base);
496         else if (RecordDecl->isDerivedFrom("boost::statechart::event"))
497         {
498             //sc.events.push_back(RecordDecl->getNameAsString());
499         }
500         return true;
501     }
502 };
503
504
505 class VisualizeStatechartConsumer : public clang::ASTConsumer
506 {
507     Model::Model model;
508     Visitor visitor;
509     string destFileName;
510 public:
511     explicit VisualizeStatechartConsumer(ASTContext *Context, std::string destFileName,
512                                          DiagnosticsEngine &D)
513         : visitor(Context, model, D), destFileName(destFileName) {}
514
515     virtual void HandleTranslationUnit(clang::ASTContext &Context) {
516         visitor.TraverseDecl(Context.getTranslationUnitDecl());
517         model.write_as_dot_file(destFileName);
518     }
519 };
520
521 class VisualizeStatechartAction : public PluginASTAction
522 {
523 protected:
524   ASTConsumer *CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) {
525     size_t dot = getCurrentFile().find_last_of('.');
526     std::string dest = getCurrentFile().substr(0, dot);
527     dest.append(".dot");
528     return new VisualizeStatechartConsumer(&CI.getASTContext(), dest, CI.getDiagnostics());
529   }
530
531   bool ParseArgs(const CompilerInstance &CI,
532                  const std::vector<std::string>& args) {
533     for (unsigned i = 0, e = args.size(); i != e; ++i) {
534       llvm::errs() << "Visualizer arg = " << args[i] << "\n";
535
536       // Example error handling.
537       if (args[i] == "-an-error") {
538         DiagnosticsEngine &D = CI.getDiagnostics();
539         unsigned DiagID = D.getCustomDiagID(
540           DiagnosticsEngine::Error, "invalid argument '" + args[i] + "'");
541         D.Report(DiagID);
542         return false;
543       }
544     }
545     if (args.size() && args[0] == "help")
546       PrintHelp(llvm::errs());
547
548     return true;
549   }
550   void PrintHelp(llvm::raw_ostream& ros) {
551     ros << "Help for Visualize Statechart plugin goes here\n";
552   }
553
554 };
555
556 static FrontendPluginRegistry::Add<VisualizeStatechartAction> X("visualize-statechart", "visualize statechart");
557
558 // Local Variables:
559 // c-basic-offset: 4
560 // End: