2 ////////////////////////////////////////////////////////////////////////////////////////
4 // This file is part of Boost Statechart Viewer.
6 // Boost Statechart Viewer is free software: you can redistribute it and/or modify
7 // it under the terms of the GNU General Public License as published by
8 // the Free Software Foundation, either version 3 of the License, or
9 // (at your option) any later version.
11 // Boost Statechart Viewer is distributed in the hope that it will be useful,
12 // but WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 // GNU General Public License for more details.
16 // You should have received a copy of the GNU General Public License
17 // along with Boost Statechart Viewer. If not, see <http://www.gnu.org/licenses/>.
19 ////////////////////////////////////////////////////////////////////////////////////////
21 //standard header files
28 #include "llvm/Support/raw_ostream.h"
29 #include "llvm/Support/raw_os_ostream.h"
32 #include "clang/AST/ASTConsumer.h"
33 #include "clang/AST/ASTContext.h"
34 #include "clang/AST/CXXInheritance.h"
35 #include "clang/AST/RecursiveASTVisitor.h"
36 #include "clang/Frontend/CompilerInstance.h"
37 #include "clang/Frontend/FrontendPluginRegistry.h"
39 using namespace clang;
// Allocates, once (function-local static), a private ios_base::xalloc() slot.
// Each stream's iword() at that slot holds its current indentation level,
// which indent / indent_inc / indent_dec read and modify.
// Presumably returns `i`; the remaining lines are not shown in this listing.
45 inline int getIndentLevelIdx() {
46 static int i = ios_base::xalloc();
// Stream manipulators used while emitting nested DOT output.
// indent:     writes 2 * (current level) fill characters (spaces unless the
//             stream's fill was changed) by padding an empty string with setw.
// indent_inc: increments the per-stream level stored in iword(getIndentLevelIdx()).
// indent_dec: decrements it.
// A fresh stream starts at level 0 because iword() slots are zero-initialized.
// NOTE(review): nothing prevents indent_dec from driving the level negative
// if calls are unbalanced.
50 ostream& indent(ostream& os) { os << setw(2*os.iword(getIndentLevelIdx())) << ""; return os; }
51 ostream& indent_inc(ostream& os) { os.iword(getIndentLevelIdx())++; return os; }
52 ostream& indent_dec(ostream& os) { os.iword(getIndentLevelIdx())--; return os; }
// A collection of child states keyed by state name. Both State (for its inner
// states) and Machine (for its outermost states) derive from it.
//   add()         - inserts `state` under state->name.
//   findContext() - searches this level, then recursively every descendant,
//                   for a state called `name`.
// NOTE(review): the map stores raw State* created with `new` elsewhere; no
// deletion is visible in this listing. Publicly deriving from std::map (which
// has no virtual destructor) is also fragile if these objects are ever deleted
// through a map pointer.
56 class Context : public map<string, State*> {
58 iterator add(State *state);
59 Context *findContext(const string &name);
// A statechart state. As a Context it also owns the map of its inner states.
// Members:
//   initialInnerState - name of the inner initial state (empty until set)
//   defferedEvents    - names of events deferred in this state (sic, "deffered")
//   inStateEvents     - names of events handled by an in_state_reaction
//   noTypedef, name   - declared on lines not shown here. noTypedef marks
//                       states without a `reactions` typedef, which
//                       operator<< draws in red.
62 class State : public Context
64 string initialInnerState;
65 list<string> defferedEvents;
66 list<string> inStateEvents;
70 explicit State(string name) : noTypedef(false), name(name) {}
71 void setInitialInnerState(string name) { initialInnerState = name; }
72 void addDeferredEvent(const string &name) { defferedEvents.push_back(name); }
73 void addInStateEvent(const string &name) { inStateEvents.push_back(name); }
74 void setNoTypedef() { noTypedef = true;}
75 friend ostream& operator<<(ostream& os, const State& s);
// Inserts `state` keyed by its name. If a state with that name is already
// present, std::map::insert keeps the existing entry (ret.second is false).
// The rest of the function is not shown in this listing.
79 Context::iterator Context::add(State *state)
81 pair<iterator, bool> ret = insert(value_type(state->name, state));
// Looks up the state called `name`. It checks the direct children first
// (find), then recurses into each child in map order.
// The return statements are not shown; presumably the result is null when
// nothing is found.
85 Context *Context::findContext(const string &name)
87 iterator i = find(name), e;
90 for (i = begin(), e = end(); i != e; ++i) {
91 Context *c = i->second->findContext(name);
// Forward declaration: operator<<(State) below recurses through this overload.
98 ostream& operator<<(ostream& os, const Context& c);
// Emits a state as a DOT node. The HTML-like label is the state name followed
// by one "<br />EVENT / defer" line per deferred event and one
// "<br />EVENT / in state" line per in-state reaction. States flagged by
// setNoTypedef() are colored red.
// The second part emits:
//   - a dashed edge to the inner initial state,
//   - a "cluster_<name>" subgraph whose initial inner state gets a double
//     border, followed by the inner states.
// Any condition guarding this second part is on a line not shown here.
// NOTE(review): static_cast<Context>(s) copy-constructs a temporary Context
// (the whole child map) just to select the Context overload.
// static_cast<const Context&>(s) selects the same overload without copying.
100 ostream& operator<<(ostream& os, const State& s)
102 string label = s.name;
103 for (list<string>::const_iterator i = s.defferedEvents.begin(), e = s.defferedEvents.end(); i != e; ++i)
104 label.append("<br />").append(*i).append(" / defer");
105 for (list<string>::const_iterator i = s.inStateEvents.begin(), e = s.inStateEvents.end(); i != e; ++i)
106 label.append("<br />").append(*i).append(" / in state");
107 if (s.noTypedef) os << indent << s.name << " [label=<" << label << ">, color=\"red\"]\n";
108 else os << indent << s.name << " [label=<" << label << ">]\n";
110 os << indent << s.name << " -> " << s.initialInnerState << " [style = dashed]\n";
111 os << indent << "subgraph cluster_" << s.name << " {\n" << indent_inc;
112 os << indent << "label = \"" << s.name << "\"\n";
113 os << indent << s.initialInnerState << " [peripheries=2]\n";
114 os << static_cast<Context>(s);
115 os << indent_dec << indent << "}\n";
// Iterates over the states of a context in name order. The loop body is not
// shown in this listing; presumably it writes each State.
// (`i++` on the iterator could be `++i`; behavior is the same.)
121 ostream& operator<<(ostream& os, const Context& c)
123 for (Context::const_iterator i = c.begin(), e = c.end(); i != e; i++) {
// Transition: one diagram edge, described by source state, destination state
// and triggering event names (immutable once constructed).
133 const string src, dst, event;
134 Transition(string src, string dst, string event) : src(src), dst(dst), event(event) {}
// Emits the transition as an indented, labelled DOT edge:
//   SRC -> DST [label = "EVENT"]
137 ostream& operator<<(ostream& os, const Transition& t)
139 os << indent << t.src << " -> " << t.dst << " [label = \"" << t.event << "\"]\n";
// A state machine: a named Context whose children are its outermost states.
// initial_state holds the name of the machine's initial state. The `name`
// member is declared on a line not shown here.
144 class Machine : public Context
147 string initial_state;
150 explicit Machine(string name) : name(name) {}
152 void setInitialState(string name) { initial_state = name; }
154 friend ostream& operator<<(ostream& os, const Machine& m);
// Emits the machine as a DOT subgraph named after it. The initial state is
// marked with a double border (peripheries=2), then all states are written
// one indentation level deeper.
// NOTE(review): static_cast<Context>(m) copies the whole state map into a
// temporary; static_cast<const Context&>(m) would avoid the copy.
157 ostream& operator<<(ostream& os, const Machine& m)
159 os << indent << "subgraph " << m.name << " {\n" << indent_inc;
160 os << indent << m.initial_state << " [peripheries=2]\n";
161 os << static_cast<Context>(m);
162 os << indent_dec << indent << "}\n";
// The complete statechart model collected from a translation unit. It holds:
//   - state machines keyed by name,
//   - states known only by forward reference (`undefined`),
//   - the list of transitions.
167 class Model : public map<string, Machine>
169 Context undefined; // For forward-declared state classes
// NOTE(review): holds raw pointers created with `new` (FindTransitVisitor,
// Visitor::HandleReaction); no delete is visible in this listing.
171 list< Transition*> transitions;
// Inserts a copy of `m` keyed by its name. An existing machine with the same
// name is kept (std::map::insert semantics).
173 iterator add(const Machine &m)
175 pair<iterator, bool> ret = insert(value_type(m.name, m));
// Records a forward-referenced state under its name, overwriting any
// previous entry with that name.
179 void addUndefinedState(State *m)
181 undefined[m->name] = m;
// Finds the context called `name`. Search order:
//   1. the undefined (forward-referenced) states,
//   2. the machines by name,
//   3. recursively inside every machine.
185 Context *findContext(const string &name)
187 Context::iterator ci = undefined.find(name);
188 if (ci != undefined.end())
190 iterator i = find(name), e;
193 for (i = begin(), e = end(); i != e; ++i) {
194 Context *c = i->second.findContext(name);
// Finds the state called `name` inside any machine. The `undefined` states
// are not searched. The static_cast assumes the Context found there is a State.
201 State *findState(const string &name)
203 for (iterator i = begin(), e = end(); i != e; ++i) {
204 Context *c = i->second.findContext(name);
206 return static_cast<State*>(c);
// Returns the forward-referenced state called `name` and presumably erases it
// from `undefined` (those lines are not shown). handleSimpleState treats a
// null result as "create a new state".
212 State *removeFromUndefinedContexts(const string &name)
214 Context::iterator ci = undefined.find(name);
215 if (ci == undefined.end())
// Writes the whole model to file `fn` as a Graphviz "digraph statecharts":
// every machine first, then every transition.
// NOTE(review): no check that the ofstream opened successfully is visible.
221 void write_as_dot_file(string fn)
223 ofstream f(fn.c_str());
224 f << "digraph statecharts {\n" << indent_inc;
225 for (iterator i = begin(), e = end(); i != e; i++)
227 for (list<Transition*>::iterator t = transitions.begin(), e = transitions.end(); t != e; ++t)
229 f << indent_dec << "}\n";
// Helper that adds a string-based isDerivedFrom() to CXXRecordDecl. Callers
// obtain it by static_cast-ing an existing CXXRecordDecl*.
235 class MyCXXRecordDecl : public CXXRecordDecl
// lookupInBases() callback. Returns true when the base's canonical declaration
// has the qualified name passed through the user-data pointer (a C string).
// NOTE(review): getAs<RecordType>() returns null for dependent bases (e.g. a
// template parameter); whether `rt` is null-checked is not visible here.
237 static bool FindBaseClassString(const CXXBaseSpecifier *Specifier,
241 string qn(static_cast<const char*>(qualName));
242 const RecordType *rt = Specifier->getType()->getAs<RecordType>();
244 TagDecl *canon = rt->getDecl()->getCanonicalDecl();
245 return canon->getQualifiedNameAsString() == qn;
// True if any direct or indirect base has qualified name `baseStr`.
// Paths are recorded only when `Base` is non-null. In that case *Base
// receives the last specifier of the first path found, i.e. the specifier
// naming the matched base. The guard around this assignment is on lines not
// shown.
249 bool isDerivedFrom(const char *baseStr, CXXBaseSpecifier const **Base = 0) const {
250 CXXBasePaths Paths(/*FindAmbiguities=*/false, /*RecordPaths=*/!!Base, /*DetectVirtual=*/false);
251 Paths.setOrigin(const_cast<MyCXXRecordDecl*>(this));
252 if (!lookupInBases(&FindBaseClassString, const_cast<char*>(baseStr), Paths))
255 *Base = Paths.front().back().Base;
// Scans the body of a custom react() method for member calls that affect the
// diagram:
//   defer_event   - adds EventType as a deferred event of SrcState;
//   transit<Dst>  - adds a transition SrcState -> Dst labelled with EventType.
// Other member names are skipped; the statement after the "transit" check is
// on a line not shown. The `model` member is declared on a line not shown.
// NOTE(review): getExplicitTemplateArgs()[0] assumes at least one explicit
// template argument. The getAsCXXRecordDecl() results (null for non-class
// types) are dereferenced without a visible null check.
260 class FindTransitVisitor : public RecursiveASTVisitor<FindTransitVisitor>
263 const CXXRecordDecl *SrcState;
264 const Type *EventType;
266 explicit FindTransitVisitor(Model::Model &model, const CXXRecordDecl *SrcState, const Type *EventType)
267 : model(model), SrcState(SrcState), EventType(EventType) {}
269 bool VisitMemberExpr(MemberExpr *E) {
270 if (E->getMemberNameInfo().getAsString() == "defer_event") {
271 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
273 Model::State *s = model.findState(SrcState->getName());
275 s->addDeferredEvent(Event->getName());
276 } else if (E->getMemberNameInfo().getAsString() != "transit")
278 if (E->hasExplicitTemplateArgs()) {
279 const Type *DstStateType = E->getExplicitTemplateArgs()[0].getArgument().getAsType().getTypePtr();
280 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
281 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
282 Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
283 model.transitions.push_back(T);
// Main AST visitor. For every complete class definition it recognizes
// Boost.Statechart state machines, states and events (VisitCXXRecordDecl).
// It fills the Model from their template arguments and `reactions` typedefs,
// and reports problems through custom clang diagnostics.
289 class Visitor : public RecursiveASTVisitor<Visitor>
// eventModel: an event class name plus its declaration location. Kept so that
// events never referenced by any reaction can be reported at the end.
294 eventModel(string ev, SourceLocation sourceLoc) : name(ev), loc(sourceLoc){}
// Predicate for list::remove_if: matches an eventModel by event name.
296 struct testEventModel {
298 testEventModel(string name) : eventName(name){}
299 bool operator() (const eventModel& model) {
300 if (eventName.compare(model.name) == 0)
// Custom diagnostic IDs, registered in the constructor. diag_warning has the
// format "'%0' %1", so callers stream a subject and then a message.
307 DiagnosticsEngine &Diags;
308 unsigned diag_unhandled_reaction_type, diag_unhandled_reaction_decl,
309 diag_found_state, diag_found_statemachine, diag_no_history, diag_missing_reaction, diag_warning;
310 std::vector<bool> reactMethodInReactions; // Indicates whether i-th react method is referenced from typedef reactions.
// Event classes seen so far that no reaction has referenced yet.
311 std::list<eventModel> listOfDefinedEvents;
// Ask RecursiveASTVisitor to also visit template instantiations.
314 bool shouldVisitTemplateInstantiations() const { return true; }
// Stores the AST context, model and diagnostics engine, and registers the
// custom diagnostics. Some assignment targets (presumably diag_found_state,
// diag_no_history, diag_warning) are on lines not shown.
316 explicit Visitor(ASTContext *Context, Model::Model &model, DiagnosticsEngine &Diags)
317 : ASTCtx(Context), model(model), Diags(Diags)
319 diag_found_statemachine =
320 Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found statemachine '%0'");
322 Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found state '%0'");
323 diag_unhandled_reaction_type =
324 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction type '%0'");
325 diag_unhandled_reaction_decl =
326 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction decl '%0'");
328 Diags.getCustomDiagID(DiagnosticsEngine::Error, "History is not yet supported");
329 diag_missing_reaction =
330 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Missing react method for event '%0'");
332 Diags.getCustomDiagID(DiagnosticsEngine::Warning, "'%0' %1");
// Shorthand for Diags.Report(Loc, DiagID).
335 DiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID) { return Diags.Report(Loc, DiagID); }
// Warns about every `react` member of SrcState whose index was not marked in
// reactMethodInReactions, i.e. react methods not referenced from the
// `reactions` typedef. `i` is a running index declared on a line not shown.
// NOTE(review): React comes from dyn_cast and is dereferenced without a null
// check, and getParamDecl(0) is used without checking getNumParams().
// HandleCustomReaction treats both situations (a non-method `react`, a react
// without parameters) as possible, so this can crash on such classes.
337 void checkAllReactMethods(const CXXRecordDecl *SrcState)
340 IdentifierInfo& II = ASTCtx->Idents.get("react");
341 for (DeclContext::lookup_const_result ReactRes = SrcState->lookup(DeclarationName(&II));
342 ReactRes.first != ReactRes.second; ++ReactRes.first, ++i) {
343 if (i >= reactMethodInReactions.size() || reactMethodInReactions[i] == false) {
344 CXXMethodDecl *React = dyn_cast<CXXMethodDecl>(*ReactRes.first);
345 Diag(React->getParamDecl(0)->getLocStart(), diag_warning)
346 << React->getParamDecl(0)->getType().getAsString() << " missing in typedef reactions";
// Looks for a `react` method of SrcState whose first parameter is EventType
// or an lvalue reference to it. For a match, it:
//   - scans the body with FindTransitVisitor (transit<>/defer_event calls),
//   - marks the method as referenced in reactMethodInReactions[i].
// Parameterless react methods and non-method `react` declarations get a
// warning. The return statements and the advance of `i` are on lines not
// shown. HandleReaction reports "Missing react method" when this returns
// false.
// NOTE(review): ParmType == EventType compares non-canonical Type pointers.
// Identical types spelled differently (qualified vs unqualified name, typedef)
// may not match; ASTContext::hasSameType would be more robust.
351 bool HandleCustomReaction(const CXXRecordDecl *SrcState, const Type *EventType)
354 IdentifierInfo& II = ASTCtx->Idents.get("react");
355 // TODO: Lookup for react even in base classes - probably by using Sema::LookupQualifiedName()
356 for (DeclContext::lookup_const_result ReactRes = SrcState->lookup(DeclarationName(&II));
357 ReactRes.first != ReactRes.second; ++ReactRes.first) {
358 if (CXXMethodDecl *React = dyn_cast<CXXMethodDecl>(*ReactRes.first)) {
359 if (React->getNumParams() >= 1) {
360 const ParmVarDecl *p = React->getParamDecl(0);
361 const Type *ParmType = p->getType().getTypePtr();
362 if (i == reactMethodInReactions.size()) reactMethodInReactions.push_back(false);
363 if (ParmType->isLValueReferenceType())
364 ParmType = dyn_cast<LValueReferenceType>(ParmType)->getPointeeType().getTypePtr();
365 if (ParmType == EventType) {
366 FindTransitVisitor(model, SrcState, EventType).TraverseStmt(React->getBody());
367 reactMethodInReactions[i] = true;
371 Diag(React->getLocStart(), diag_warning)
372 << React << "has not a parameter";
374 Diag((*ReactRes.first)->getSourceRange().getBegin(), diag_warning)
375 << (*ReactRes.first)->getDeclKindName() << "is not supported as react method";
// Interprets one reaction type from a state's `reactions` typedef:
//   transition<Ev, Dst>        -> adds a Transition SrcState -> Dst
//   custom_reaction<Ev>        -> analyses the matching react() method
//   deferral<Ev>               -> records Ev as deferred on the state
//   mpl::list<...>             -> recurses into each element
//   in_state_reaction<Ev, ...> -> records Ev as an in-state event
// Elaborated types are unwrapped first. Other templates and non-template types
// are reported as unhandled reaction types. Each handled event is removed from
// listOfDefinedEvents so it is not reported as unused.
// NOTE(review): getAsCXXRecordDecl() results (Event, DstState) are
// dereferenced without null checks; they are null for non-class arguments.
381 void HandleReaction(const Type *T, const SourceLocation Loc, CXXRecordDecl *SrcState)
383 // TODO: Improve Loc tracking
384 if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(T))
385 HandleReaction(ET->getNamedType().getTypePtr(), Loc, SrcState);
386 else if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(T)) {
387 string name = TST->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
388 if (name == "boost::statechart::transition") {
389 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
390 const Type *DstStateType = TST->getArg(1).getAsType().getTypePtr();
391 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
392 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
393 listOfDefinedEvents.remove_if(testEventModel(Event->getNameAsString()));
395 Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
396 model.transitions.push_back(T);
397 } else if (name == "boost::statechart::custom_reaction") {
398 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
399 if (!HandleCustomReaction(SrcState, EventType)) {
400 Diag(SrcState->getLocation(), diag_missing_reaction) << EventType->getAsCXXRecordDecl()->getName();
402 listOfDefinedEvents.remove_if(testEventModel(EventType->getAsCXXRecordDecl()->getNameAsString()));
403 } else if (name == "boost::statechart::deferral") {
404 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
405 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
406 listOfDefinedEvents.remove_if(testEventModel(Event->getNameAsString()));
408 Model::State *s = model.findState(SrcState->getName());
410 s->addDeferredEvent(Event->getName());
411 } else if (name == "boost::mpl::list") {
412 for (TemplateSpecializationType::iterator Arg = TST->begin(), End = TST->end(); Arg != End; ++Arg)
413 HandleReaction(Arg->getAsType().getTypePtr(), Loc, SrcState);
414 } else if (name == "boost::statechart::in_state_reaction") {
415 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
416 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
417 listOfDefinedEvents.remove_if(testEventModel(Event->getNameAsString()));
419 Model::State *s = model.findState(SrcState->getName());
421 s->addInStateEvent(Event->getName());
424 Diag(Loc, diag_unhandled_reaction_type) << name;
426 Diag(Loc, diag_unhandled_reaction_type) << T->getTypeClassName();
// Handles a `reactions` member declaration. A typedef is resolved to its
// underlying type and passed to the Type overload; other declaration kinds
// are reported as unhandled. It also calls checkAllReactMethods(SrcState).
429 void HandleReaction(const NamedDecl *Decl, CXXRecordDecl *SrcState)
431 if (const TypedefDecl *r = dyn_cast<TypedefDecl>(Decl))
432 HandleReaction(r->getCanonicalDecl()->getUnderlyingType().getTypePtr(),
433 r->getLocStart(), SrcState);
435 Diag(Decl->getLocation(), diag_unhandled_reaction_decl) << Decl->getDeclKindName();
436 checkAllReactMethods(SrcState);
// Returns the location info of template argument ArgNum of the written type
// T, looking through elaborated-type sugar. When T is not a template
// specialization, or has too few arguments, it warns and returns an empty
// (Null-kind) TemplateArgumentLoc. How `ignore` affects this is on lines not
// shown.
// NOTE(review): dyn_cast on TypeLoc pointers is an old clang API; newer clang
// uses T.getAs<ElaboratedTypeLoc>() / T.getAs<TemplateSpecializationTypeLoc>().
439 TemplateArgumentLoc getTemplateArgLoc(const TypeLoc &T, unsigned ArgNum, bool ignore)
441 if (const ElaboratedTypeLoc *ET = dyn_cast<ElaboratedTypeLoc>(&T))
442 return getTemplateArgLoc(ET->getNamedTypeLoc(), ArgNum, ignore);
443 else if (const TemplateSpecializationTypeLoc *TST = dyn_cast<TemplateSpecializationTypeLoc>(&T)) {
444 if (TST->getNumArgs() >= ArgNum+1) {
445 return TST->getArgLoc(ArgNum);
448 Diag(TST->getBeginLoc(), diag_warning) << TST->getType()->getTypeClassName() << "has not enough arguments" << TST->getSourceRange();
450 Diag(T.getBeginLoc(), diag_warning) << T.getType()->getTypeClassName() << "type as template argument is not supported" << T.getSourceRange();
451 return TemplateArgumentLoc();
// Same as getTemplateArgLoc, applied to the type as written in a base-class
// specifier.
454 TemplateArgumentLoc getTemplateArgLocOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, bool ignore) {
455 return getTemplateArgLoc(Base->getTypeSourceInfo()->getTypeLoc(), ArgNum, ignore);
// Returns the class declaration named by template argument ArgNum of the base
// specifier, and stores that argument's location in Loc.
//   - Non-class type arguments yield null (getAsCXXRecordDecl).
//   - Null arguments were already diagnosed by getTemplateArgLoc.
//   - Other argument kinds get an "unsupported kind" warning.
// NOTE(review): getKind() is streamed as a raw enum value, so that warning
// shows a number rather than a kind name.
458 CXXRecordDecl *getTemplateArgDeclOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, TemplateArgumentLoc &Loc, bool ignore = false) {
459 Loc = getTemplateArgLocOfBase(Base, ArgNum, ignore);
460 switch (Loc.getArgument().getKind()) {
461 case TemplateArgument::Type:
462 return Loc.getTypeSourceInfo()->getType()->getAsCXXRecordDecl();
463 case TemplateArgument::Null:
464 // Diag() was already called
467 Diag(Loc.getSourceRange().getBegin(), diag_warning) << Loc.getArgument().getKind() << "unsupported kind" << Loc.getSourceRange();
// Convenience overload that discards the argument location.
472 CXXRecordDecl *getTemplateArgDeclOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, bool ignore = false) {
473 TemplateArgumentLoc Loc;
474 return getTemplateArgDeclOfBase(Base, ArgNum, Loc, ignore);
// Registers a state class derived from simple_state. `Base` is its
// simple_state<> base specifier, whose template arguments are used as:
//   1 - the context (parent state or state machine),
//   2 - the inner initial state (optional; warnings suppressed via ignore=true).
// Also processes the `reactions` declarations and warns when there are none.
// `state` and `typedef_num` are declared on lines not shown.
// NOTE(review): `Context` may be null when argument 1 is not a class type;
// Context->getName() is called without a visible null check.
477 void handleSimpleState(CXXRecordDecl *RecordDecl, const CXXBaseSpecifier *Base)
480 string name(RecordDecl->getName()); //getQualifiedNameAsString());
481 Diag(RecordDecl->getLocStart(), diag_found_state) << name;
482 reactMethodInReactions.clear();
485 // Either we saw a reference to forward declared state
486 // before, or we create a new state.
487 if (!(state = model.removeFromUndefinedContexts(name)))
488 state = new Model::State(name);
// Look up the parent context. A placeholder State for it is registered as
// undefined (presumably only when the lookup fails).
490 CXXRecordDecl *Context = getTemplateArgDeclOfBase(Base, 1);
492 Model::Context *c = model.findContext(Context->getName());
494 Model::State *s = new Model::State(Context->getName());
495 model.addUndefinedState(s);
500 //TODO support more innitial states
// NOTE(review): the `else if` below compares getNameAsString() with
// "boost::mpl::list<>". getNameAsString() returns only the unqualified
// identifier (e.g. "list"), so the comparison never matches and the
// "not supported" warning for an mpl::list inner initial state is never
// emitted.
501 TemplateArgumentLoc Loc;
502 if (MyCXXRecordDecl *InnerInitialState =
503 static_cast<MyCXXRecordDecl*>(getTemplateArgDeclOfBase(Base, 2, Loc, true))) {
504 if (InnerInitialState->isDerivedFrom("boost::statechart::simple_state") ||
505 InnerInitialState->isDerivedFrom("boost::statechart::state_machine")) {
506 state->setInitialInnerState(InnerInitialState->getName());
508 else if (!InnerInitialState->getNameAsString().compare("boost::mpl::list<>"))
509 Diag(Loc.getTypeSourceInfo()->getTypeLoc().getBeginLoc(), diag_warning)
510 << InnerInitialState->getName() << " as inner initial state is not supported" << Loc.getSourceRange();
513 // if (CXXRecordDecl *History = getTemplateArgDecl(Base->getType().getTypePtr(), 3))
514 // Diag(History->getLocStart(), diag_no_history);
// Process every `reactions` declaration found directly in this class,
// counting them in typedef_num.
516 IdentifierInfo& II = ASTCtx->Idents.get("reactions");
517 // TODO: Lookup for reactions even in base classes - probably by using Sema::LookupQualifiedName()
518 for (DeclContext::lookup_result Reactions = RecordDecl->lookup(DeclarationName(&II));
519 Reactions.first != Reactions.second; ++Reactions.first, typedef_num++)
520 HandleReaction(*Reactions.first, RecordDecl);
521 if(typedef_num == 0) {
522 Diag(RecordDecl->getLocStart(), diag_warning)
523 << RecordDecl->getName() << "state has no typedef for reactions";
524 state->setNoTypedef();
// Registers a state machine class. Template argument 1 of its state_machine<>
// base names the initial state. Adding `m` to the model happens on lines not
// shown (presumably model.add(m)).
528 void handleStateMachine(CXXRecordDecl *RecordDecl, const CXXBaseSpecifier *Base)
530 Model::Machine m(RecordDecl->getName());
531 Diag(RecordDecl->getLocStart(), diag_found_statemachine) << m.name;
533 if (MyCXXRecordDecl *InitialState =
534 static_cast<MyCXXRecordDecl*>(getTemplateArgDeclOfBase(Base, 1)))
535 m.setInitialState(InitialState->getName());
// Called for every class definition, including template instantiations.
// It skips incomplete definitions (presumably returning true) and a few known
// base templates, then dispatches:
//   simple_state descendants   -> states,
//   state_machine descendants  -> machines,
//   event descendants          -> remembered for the unused-event report.
// NOTE(review): Boost.Statechart's class is spelled
// "asynchronous_state_machine"; the literal "assynchronous_state_machine"
// below never matches, so that exclusion has no effect. "TimedState" and
// "TimedSimpleState" are unqualified, project-specific names.
// NOTE(review): the static_cast to MyCXXRecordDecl* is formally undefined
// behavior, because the object is not a MyCXXRecordDecl.
539 bool VisitCXXRecordDecl(CXXRecordDecl *Declaration)
541 if (!Declaration->isCompleteDefinition())
543 if (Declaration->getQualifiedNameAsString() == "boost::statechart::state" ||
544 Declaration->getQualifiedNameAsString() == "TimedState" ||
545 Declaration->getQualifiedNameAsString() == "TimedSimpleState" ||
546 Declaration->getQualifiedNameAsString() == "boost::statechart::assynchronous_state_machine")
547 return true; // This is an "abstract class" not a real state or real state machine
549 MyCXXRecordDecl *RecordDecl = static_cast<MyCXXRecordDecl*>(Declaration);
550 const CXXBaseSpecifier *Base;
552 if (RecordDecl->isDerivedFrom("boost::statechart::simple_state", &Base))
553 handleSimpleState(RecordDecl, Base);
554 else if (RecordDecl->isDerivedFrom("boost::statechart::state_machine", &Base))
555 handleStateMachine(RecordDecl, Base);
556 else if (RecordDecl->isDerivedFrom("boost::statechart::event"))
557 listOfDefinedEvents.push_back(eventModel(RecordDecl->getNameAsString(), RecordDecl->getLocation()));
// Emits a warning at each event class that no reaction referenced.
560 void printUnusedEventDefinitions() {
561 for(list<eventModel>::iterator it = listOfDefinedEvents.begin(); it!=listOfDefinedEvents.end(); it++)
562 Diag((*it).loc, diag_warning)
563 << (*it).name << "event defined but not used in any state";
// AST consumer: once the translation unit is parsed, it
//   1. runs the Visitor over it,
//   2. reports events no reaction uses,
//   3. writes the collected model as a DOT file to destFileName.
// The visitor, model and destFileName members are declared on lines not shown.
568 class VisualizeStatechartConsumer : public clang::ASTConsumer
574 explicit VisualizeStatechartConsumer(ASTContext *Context, std::string destFileName,
575 DiagnosticsEngine &D)
576 : visitor(Context, model, D), destFileName(destFileName) {}
578 virtual void HandleTranslationUnit(clang::ASTContext &Context) {
579 visitor.TraverseDecl(Context.getTranslationUnitDecl());
580 visitor.printUnusedEventDefinitions();
581 model.write_as_dot_file(destFileName);
// Clang plugin action that creates a VisualizeStatechartConsumer for each input.
585 class VisualizeStatechartAction : public PluginASTAction
// The output name is the current input path with everything from its last '.'
// removed (the whole path if there is no '.'). Any suffix appended afterwards
// is on a line not shown.
// NOTE(review): the last '.' may belong to a directory component
// (e.g. "a.b/file"), which would truncate the path.
// llvm::sys::path::replace_extension would be safer.
// NOTE(review): returning a raw ASTConsumer* matches old clang plugin APIs;
// newer clang expects std::unique_ptr<ASTConsumer>.
588 ASTConsumer *CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) {
589 size_t dot = getCurrentFile().find_last_of('.');
590 std::string dest = getCurrentFile().substr(0, dot);
592 return new VisualizeStatechartConsumer(&CI.getASTContext(), dest, CI.getDiagnostics());
// Echoes every plugin argument to stderr. The "-an-error" branch is
// placeholder example code, as its own comment says. If the first argument
// is "help", the help text is printed. Return statements are on lines not
// shown.
595 bool ParseArgs(const CompilerInstance &CI,
596 const std::vector<std::string>& args) {
597 for (unsigned i = 0, e = args.size(); i != e; ++i) {
598 llvm::errs() << "Visualizer arg = " << args[i] << "\n";
600 // Example error handling.
601 if (args[i] == "-an-error") {
602 DiagnosticsEngine &D = CI.getDiagnostics();
603 unsigned DiagID = D.getCustomDiagID(
604 DiagnosticsEngine::Error, "invalid argument '" + args[i] + "'");
609 if (args.size() && args[0] == "help")
610 PrintHelp(llvm::errs());
// NOTE(review): placeholder help text; it does not describe the plugin's
// options or output.
614 void PrintHelp(llvm::raw_ostream& ros) {
615 ros << "Help for Visualize Statechart plugin goes here\n";
// Registers the plugin with clang's frontend plugin registry under the name
// "visualize-statechart" (the name passed to clang's -plugin option).
620 static FrontendPluginRegistry::Add<VisualizeStatechartAction> X("visualize-statechart", "visualize statechart");