2 ////////////////////////////////////////////////////////////////////////////////////////
4 // This file is part of Boost Statechart Viewer.
6 // Boost Statechart Viewer is free software: you can redistribute it and/or modify
7 // it under the terms of the GNU General Public License as published by
8 // the Free Software Foundation, either version 3 of the License, or
9 // (at your option) any later version.
11 // Boost Statechart Viewer is distributed in the hope that it will be useful,
12 // but WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 // GNU General Public License for more details.
16 // You should have received a copy of the GNU General Public License
17 // along with Boost Statechart Viewer. If not, see <http://www.gnu.org/licenses/>.
19 ////////////////////////////////////////////////////////////////////////////////////////
21 //standard header files
28 #include "llvm/Support/raw_ostream.h"
29 #include "llvm/Support/raw_os_ostream.h"
32 #include "clang/AST/ASTConsumer.h"
33 #include "clang/AST/ASTContext.h"
34 #include "clang/AST/CXXInheritance.h"
35 #include "clang/AST/RecursiveASTVisitor.h"
36 #include "clang/Frontend/CompilerInstance.h"
37 #include "clang/Frontend/FrontendPluginRegistry.h"
39 using namespace clang;
// Per-stream indentation support for the DOT writer.
// getIndentLevelIdx() allocates one ios_base::iword slot, once (function-local
// static), and the three manipulators below keep the current nesting level there.
45 inline int getIndentLevelIdx() {
46 static int i = ios_base::xalloc();
// Manipulator: pads an empty string to width 2 * (current level), i.e. emits
// 2*level fill characters (spaces with the default fill).
50 ostream& indent(ostream& os) { os << setw(2*os.iword(getIndentLevelIdx())) << ""; return os; }
// Manipulator: increments the nesting level stored in the stream's iword slot.
51 ostream& indent_inc(ostream& os) { os.iword(getIndentLevelIdx())++; return os; }
// Manipulator: decrements the nesting level; callers pair it with indent_inc.
52 ostream& indent_dec(ostream& os) { os.iword(getIndentLevelIdx())--; return os; }
// A named collection of states: maps a state's name to its State object.
// State (inner states) and Machine (top-level states) both derive from it,
// so the model is a tree of Contexts.
56 class Context : public map<string, State*> {
// Insert a state keyed by its name.
58 iterator add(State *state);
// Recursively search this context and its inner states for the context called name.
59 Context *findContext(const string &name);
// One statechart state. It is itself a Context holding its inner states.
62 class State : public Context
// Name of the inner initial state (drawn with a double border).
64 string initialInnerState;
// Names of events deferred in this state.
// NOTE(review): member name misspelled ("deffered"); it is private, so a rename is safe.
65 list<string> defferedEvents;
// Names of events handled by in_state_reaction in this state.
66 list<string> inStateEvents;
70 explicit State(string name) : noTypedef(false), name(name) {}
71 void setInitialInnerState(string name) { initialInnerState = name; }
72 void addDeferredEvent(const string &name) { defferedEvents.push_back(name); }
73 void addInStateEvent(const string &name) { inStateEvents.push_back(name); }
// Marks a state class that has no `reactions` typedef; such states are printed in red.
74 void setNoTypedef() { noTypedef = true;}
75 friend ostream& operator<<(ostream& os, const State& s);
// Insert state under state->name. std::map::insert leaves an existing entry
// with the same name untouched; ret.first refers to whichever entry is present.
79 Context::iterator Context::add(State *state)
81 pair<iterator, bool> ret = insert(value_type(state->name, state));
// Look for the context called name: first among the direct children, then
// recursively inside each child state. (The return statements are not
// visible in this excerpt.)
85 Context *Context::findContext(const string &name)
87 iterator i = find(name), e;
90 for (i = begin(), e = end(); i != e; ++i) {
91 Context *c = i->second->findContext(name);
98 ostream& operator<<(ostream& os, const Context& c);
// Emit one State as Graphviz DOT.
//
// The node uses an HTML-like label (label=<...>): the state name, followed by
// one "<br />EVENT / defer" line per deferred event and one
// "<br />EVENT / in state" line per in-state reaction. States flagged with
// setNoTypedef() are drawn red. Next come a dashed edge to the inner initial
// state, a "cluster_<name>" subgraph with that initial state double-bordered,
// and the recursively printed inner states. Any condition guarding that part
// is outside this excerpt.
100 ostream& operator<<(ostream& os, const State& s)
102 string label = s.name;
103 for (list<string>::const_iterator i = s.defferedEvents.begin(), e = s.defferedEvents.end(); i != e; ++i)
104 label.append("<br />").append(*i).append(" / defer");
105 for (list<string>::const_iterator i = s.inStateEvents.begin(), e = s.inStateEvents.end(); i != e; ++i)
106 label.append("<br />").append(*i).append(" / in state");
107 if (s.noTypedef) os << indent << s.name << " [label=<" << label << ">, color=\"red\"]\n";
108 else os << indent << s.name << " [label=<" << label << ">]\n";
110 os << indent << s.name << " -> " << s.initialInnerState << " [style = dashed]\n";
111 os << indent << "subgraph cluster_" << s.name << " {\n" << indent_inc;
112 os << indent << "label = \"" << s.name << "\"\n";
113 os << indent << s.initialInnerState << " [peripheries=2]\n";
// Print the inner states through a reference to the Context base. A by-value
// cast would slice-copy the whole state map into a temporary first.
114 os << static_cast<const Context&>(s);
115 os << indent_dec << indent << "}\n";
// Emit every state contained in c, in map (name) order. The loop body is not
// visible in this excerpt; presumably it streams each State via operator<<(State).
121 ostream& operator<<(ostream& os, const Context& c)
123 for (Context::const_iterator i = c.begin(), e = c.end(); i != e; i++) {
// A labelled edge between two states: source state, destination state and
// triggering event names.
133 const string src, dst, event;
134 Transition(string src, string dst, string event) : src(src), dst(dst), event(event) {}
// Emit the transition as a DOT edge: SRC -> DST [label = "EVENT"].
137 ostream& operator<<(ostream& os, const Transition& t)
139 os << indent << t.src << " -> " << t.dst << " [label = \"" << t.event << "\"]\n";
// A top-level state machine: a Context of its states plus the name of its
// initial state. It is printed as a DOT subgraph.
144 class Machine : public Context
147 string initial_state;
150 explicit Machine(string name) : name(name) {}
152 void setInitialState(string name) { initial_state = name; }
154 friend ostream& operator<<(ostream& os, const Machine& m);
// Emit the machine as "subgraph NAME { ... }": the initial state is drawn
// with a double border (peripheries=2), then the contained Context is printed
// via operator<<(ostream&, const Context&).
157 ostream& operator<<(ostream& os, const Machine& m)
159 os << indent << "subgraph " << m.name << " {\n" << indent_inc;
160 os << indent << m.initial_state << " [peripheries=2]\n";
// Bind to the Context base by reference. A by-value cast would slice-copy
// the whole state map into a temporary just to print it.
161 os << static_cast<const Context&>(m);
162 os << indent_dec << indent << "}\n";
// The whole extracted model: state machines keyed by name, states seen only
// as forward references, and every transition found.
167 class Model : public map<string, Machine>
169 Context undefined; // For forward-declared state classes
// Transitions are allocated with `new` by the visitors.
// NOTE(review): no delete is visible in this excerpt, so they are likely leaked.
171 list< Transition*> transitions;
// Insert a copy of m keyed by m.name; an existing machine with the same name is kept.
173 iterator add(const Machine &m)
175 pair<iterator, bool> ret = insert(value_type(m.name, m));
// Remember a state that was referenced before its definition was seen.
// operator[] overwrites any previous entry with the same name.
179 void addUndefinedState(State *m)
181 undefined[m->name] = m;
// Look for a context called name: first among the undefined (forward-referenced)
// states, then among the machines, then recursively inside each machine.
185 Context *findContext(const string &name)
187 Context::iterator ci = undefined.find(name);
188 if (ci != undefined.end())
190 iterator i = find(name), e;
193 for (i = begin(), e = end(); i != e; ++i) {
194 Context *c = i->second.findContext(name);
// Search each machine's state tree for name and return it as a State*.
// The static_cast assumes the match is a State, not a Machine. TODO confirm.
201 State *findState(const string &name)
203 for (iterator i = begin(), e = end(); i != e; ++i) {
204 Context *c = i->second.findContext(name);
206 return static_cast<State*>(c);
// If name was registered as undefined, take it out of that set. The removal
// and return lines are not visible in this excerpt.
212 State *removeFromUndefinedContexts(const string &name)
214 Context::iterator ci = undefined.find(name);
215 if (ci == undefined.end())
// Write the model to fn as one Graphviz digraph: each machine, then each transition.
// NOTE(review): the ofstream is not checked for open failure in the visible lines.
221 void write_as_dot_file(string fn)
223 ofstream f(fn.c_str());
224 f << "digraph statecharts {\n" << indent_inc;
225 for (iterator i = begin(), e = end(); i != e; i++)
227 for (list<Transition*>::iterator t = transitions.begin(), e = transitions.end(); t != e; ++t)
229 f << indent_dec << "}\n";
// Thin extension of CXXRecordDecl that tests for a base class by its qualified
// name. Callers get it by static_cast from plain CXXRecordDecl pointers, which
// is only sound while this class adds no data members or virtual functions.
235 class MyCXXRecordDecl : public CXXRecordDecl
// lookupInBases callback: true when the base named by Specifier is, canonically,
// the class whose qualified name is passed in the user-data pointer qualName.
// NOTE(review): getAs<RecordType>() can be null (e.g. dependent bases); a guard
// may exist on a line not shown here.
237 static bool FindBaseClassString(const CXXBaseSpecifier *Specifier,
241 string qn(static_cast<const char*>(qualName));
242 const RecordType *rt = Specifier->getType()->getAs<RecordType>();
244 TagDecl *canon = rt->getDecl()->getCanonicalDecl();
245 return canon->getQualifiedNameAsString() == qn;
// True if some base class has the qualified name baseStr. Paths are recorded
// only when Base is non-null; *Base then receives the last specifier of the
// first path found, i.e. the one naming baseStr.
249 bool isDerivedFrom(const char *baseStr, CXXBaseSpecifier const **Base = 0) const {
250 CXXBasePaths Paths(/*FindAmbiguities=*/false, /*RecordPaths=*/!!Base, /*DetectVirtual=*/false);
251 Paths.setOrigin(const_cast<MyCXXRecordDecl*>(this));
252 if (!lookupInBases(&FindBaseClassString, const_cast<char*>(baseStr), Paths))
255 *Base = Paths.front().back().Base;
// Walks the body of a react() method for one (source state, event) pair:
// - a member call named "defer_event" records EventType as deferred in SrcState;
// - a member call named "transit" with explicit template arguments records a
//   Transition from SrcState to the first template argument, labelled with the event.
// Other member names are skipped. The early-return line is not visible here.
260 class FindTransitVisitor : public RecursiveASTVisitor<FindTransitVisitor>
263 const CXXRecordDecl *SrcState;
264 const Type *EventType;
266 explicit FindTransitVisitor(Model::Model &model, const CXXRecordDecl *SrcState, const Type *EventType)
267 : model(model), SrcState(SrcState), EventType(EventType) {}
269 bool VisitMemberExpr(MemberExpr *E) {
270 if (E->getMemberNameInfo().getAsString() == "defer_event") {
271 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
273 Model::State *s = model.findState(SrcState->getName());
275 s->addDeferredEvent(Event->getName());
276 } else if (E->getMemberNameInfo().getAsString() != "transit")
278 if (E->hasExplicitTemplateArgs()) {
// transit<Destination>(): template argument 0 is the destination state.
279 const Type *DstStateType = E->getExplicitTemplateArgs()[0].getArgument().getAsType().getTypePtr();
280 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
281 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
// Ownership passes to model.transitions; no delete is visible in this excerpt.
282 Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
283 model.transitions.push_back(T);
// Main AST visitor: finds states, state machines and events, and fills the Model.
289 class Visitor : public RecursiveASTVisitor<Visitor>
// An event class definition (name + location); kept until some reaction uses it.
294 eventModel(string ev, SourceLocation sourceLoc) : name(ev), loc(sourceLoc){}
// Predicate for list::remove_if: matches an eventModel by exact name.
296 struct testEventModel {
298 testEventModel(string name) : eventName(name){}
299 bool operator() (const eventModel& model) {
300 if (eventName.compare(model.name) == 0)
307 DiagnosticsEngine &Diags;
// Custom diagnostic IDs, registered in the constructor.
308 unsigned diag_unhandled_reaction_type, diag_unhandled_reaction_decl,
309 diag_found_state, diag_found_statemachine, diag_no_history, diag_missing_reaction, diag_warning;
310 std::vector<bool> reactMethodInReactions; // Indicates whether i-th react method is referenced from typedef reactions.
// Events defined but not (yet) referenced by any reaction; leftovers are reported at the end.
311 std::list<eventModel> listOfDefinedEvents;
// Also visit implicit template instantiations, not only written declarations.
314 bool shouldVisitTemplateInstantiations() const { return true; }
// Stores the AST context, the model being filled and the diagnostics engine,
// then registers the plugin's custom diagnostics. Some assignment targets are
// on lines not shown. diag_warning is generic: "'%0' %1", i.e. a subject and a message.
316 explicit Visitor(ASTContext *Context, Model::Model &model, DiagnosticsEngine &Diags)
317 : ASTCtx(Context), model(model), Diags(Diags)
319 diag_found_statemachine =
320 Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found statemachine '%0'");
322 Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found state '%0'");
323 diag_unhandled_reaction_type =
324 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction type '%0'");
325 diag_unhandled_reaction_decl =
326 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction decl '%0'");
328 Diags.getCustomDiagID(DiagnosticsEngine::Error, "History is not yet supported");
329 diag_missing_reaction =
330 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Missing react method for event '%0'");
332 Diags.getCustomDiagID(DiagnosticsEngine::Warning, "'%0' %1");
// Shorthand for Diags.Report(Loc, DiagID); stream arguments into the result.
335 DiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID) { return Diags.Report(Loc, DiagID); }
// Emits one location-less warning before traversal starts.
337 void initializeDiagnostic()
339 /* Initialization of diagnostics: if this is not done, no notes are printed before the first warning or error. */
340 Diags.Report(SourceLocation(), Diags.getCustomDiagID(DiagnosticsEngine::Warning, "Visualizer plugin is running!\n\n"));
// Warn about each react() declaration of SrcState whose slot in
// reactMethodInReactions is missing or false. HandleCustomReaction sets a slot
// when that react() overload matches an event listed in the `reactions` typedef.
343 void checkAllReactMethods(const CXXRecordDecl *SrcState)
346 IdentifierInfo& II = ASTCtx->Idents.get("react");
347 for (DeclContext::lookup_const_result ReactRes = SrcState->lookup(DeclarationName(&II));
348 ReactRes.first != ReactRes.second; ++ReactRes.first, ++i) {
349 if (i >= reactMethodInReactions.size() || reactMethodInReactions[i] == false) {
350 CXXMethodDecl *React = dyn_cast<CXXMethodDecl>(*ReactRes.first);
// The lookup can yield non-method declarations and react() overloads without
// parameters (HandleCustomReaction copes with both). Skip them instead of
// dereferencing a null pointer or asking for a parameter that does not exist.
351 if (React && React->getNumParams() >= 1) Diag(React->getParamDecl(0)->getLocStart(), diag_warning)
352 << React->getParamDecl(0)->getType().getAsString() << " missing in typedef reactions";
// Find the react() method of SrcState whose first parameter is EventType
// (by value or by lvalue reference), and scan its body for transit/defer_event
// calls. The return statements are not visible; the caller reports
// diag_missing_reaction when this returns false.
// NOTE(review): types are compared by Type* identity, not canonical type, so a
// typedef'd parameter type might not match. TODO confirm.
357 bool HandleCustomReaction(const CXXRecordDecl *SrcState, const Type *EventType)
360 IdentifierInfo& II = ASTCtx->Idents.get("react");
361 // TODO: Lookup for react even in base classes - probably by using Sema::LookupQualifiedName()
362 for (DeclContext::lookup_const_result ReactRes = SrcState->lookup(DeclarationName(&II));
363 ReactRes.first != ReactRes.second; ++ReactRes.first) {
364 if (CXXMethodDecl *React = dyn_cast<CXXMethodDecl>(*ReactRes.first)) {
365 if (React->getNumParams() >= 1) {
366 const ParmVarDecl *p = React->getParamDecl(0);
367 const Type *ParmType = p->getType().getTypePtr();
// Grow the per-react() bookkeeping lazily so index i is valid below.
368 if (i == reactMethodInReactions.size()) reactMethodInReactions.push_back(false);
// Strip a reference so `react(const Ev&)` matches the event type.
// NOTE(review): dyn_cast returns null for a reference hidden behind sugar (typedef).
369 if (ParmType->isLValueReferenceType())
370 ParmType = dyn_cast<LValueReferenceType>(ParmType)->getPointeeType().getTypePtr();
371 if (ParmType == EventType) {
372 FindTransitVisitor(model, SrcState, EventType).TraverseStmt(React->getBody());
373 reactMethodInReactions[i] = true;
377 Diag(React->getLocStart(), diag_warning)
378 << React << "has not a parameter";
380 Diag((*ReactRes.first)->getSourceRange().getBegin(), diag_warning)
381 << (*ReactRes.first)->getDeclKindName() << "is not supported as react method";
// Interpret one reaction type from a state's `reactions` typedef and update the model:
//   transition<Event, Dest>        -> Transition edge
//   custom_reaction<Event>         -> scan the matching react() body
//   deferral<Event>                -> deferred event on the state
//   in_state_reaction<Event, ...>  -> in-state event on the state
//   mpl::list<...>                 -> recurse into each element
// Every handled event is removed from listOfDefinedEvents ("used").
// Anything else is reported with diag_unhandled_reaction_type.
// NOTE(review): getAsCXXRecordDecl() results are dereferenced without visible null checks.
387 void HandleReaction(const Type *T, const SourceLocation Loc, CXXRecordDecl *SrcState)
389 // TODO: Improve Loc tracking
390 if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(T))
391 HandleReaction(ET->getNamedType().getTypePtr(), Loc, SrcState);
392 else if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(T)) {
393 string name = TST->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
394 if (name == "boost::statechart::transition") {
395 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
396 const Type *DstStateType = TST->getArg(1).getAsType().getTypePtr();
397 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
398 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
399 listOfDefinedEvents.remove_if(testEventModel(Event->getNameAsString()));
// NOTE(review): this local `T` shadows the parameter `T` inside this branch.
401 Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
402 model.transitions.push_back(T);
403 } else if (name == "boost::statechart::custom_reaction") {
404 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
405 if (!HandleCustomReaction(SrcState, EventType)) {
406 Diag(SrcState->getLocation(), diag_missing_reaction) << EventType->getAsCXXRecordDecl()->getName();
408 listOfDefinedEvents.remove_if(testEventModel(EventType->getAsCXXRecordDecl()->getNameAsString()));
409 } else if (name == "boost::statechart::deferral") {
410 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
411 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
412 listOfDefinedEvents.remove_if(testEventModel(Event->getNameAsString()));
414 Model::State *s = model.findState(SrcState->getName());
416 s->addDeferredEvent(Event->getName());
417 } else if (name == "boost::mpl::list") {
418 for (TemplateSpecializationType::iterator Arg = TST->begin(), End = TST->end(); Arg != End; ++Arg)
419 HandleReaction(Arg->getAsType().getTypePtr(), Loc, SrcState);
420 } else if (name == "boost::statechart::in_state_reaction") {
421 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
422 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
423 listOfDefinedEvents.remove_if(testEventModel(Event->getNameAsString()));
425 Model::State *s = model.findState(SrcState->getName());
427 s->addInStateEvent(Event->getName());
430 Diag(Loc, diag_unhandled_reaction_type) << name;
432 Diag(Loc, diag_unhandled_reaction_type) << T->getTypeClassName();
// Handle one `reactions` declaration of SrcState. Only typedefs are supported;
// their underlying type is dispatched to HandleReaction(const Type*, ...).
// Other declaration kinds get diag_unhandled_reaction_decl. Then every react()
// method is checked for a matching reaction.
// NOTE(review): the excerpt does not show whether checkAllReactMethods runs
// unconditionally or only in the else branch. Confirm against the full file.
435 void HandleReaction(const NamedDecl *Decl, CXXRecordDecl *SrcState)
437 if (const TypedefDecl *r = dyn_cast<TypedefDecl>(Decl))
438 HandleReaction(r->getCanonicalDecl()->getUnderlyingType().getTypePtr(),
439 r->getLocStart(), SrcState);
441 Diag(Decl->getLocation(), diag_unhandled_reaction_decl) << Decl->getDeclKindName();
442 checkAllReactMethods(SrcState);
// Return the source location info of template argument ArgNum (0-based) of the
// template specialization spelled by T, looking through elaborated types.
// If the type is not a template specialization, or has fewer written arguments
// than requested, a warning is issued and an empty TemplateArgumentLoc (Null
// kind) is returned. Presumably `ignore` suppresses the "not enough arguments"
// warning for optional/defaulted arguments; that line is not shown.
445 TemplateArgumentLoc getTemplateArgLoc(const TypeLoc &T, unsigned ArgNum, bool ignore)
447 if (const ElaboratedTypeLoc *ET = dyn_cast<ElaboratedTypeLoc>(&T))
448 return getTemplateArgLoc(ET->getNamedTypeLoc(), ArgNum, ignore);
449 else if (const TemplateSpecializationTypeLoc *TST = dyn_cast<TemplateSpecializationTypeLoc>(&T)) {
450 if (TST->getNumArgs() >= ArgNum+1) {
451 return TST->getArgLoc(ArgNum);
454 Diag(TST->getBeginLoc(), diag_warning) << TST->getType()->getTypeClassName() << "has not enough arguments" << TST->getSourceRange();
456 Diag(T.getBeginLoc(), diag_warning) << T.getType()->getTypeClassName() << "type as template argument is not supported" << T.getSourceRange();
457 return TemplateArgumentLoc();
// Same as getTemplateArgLoc, applied to the type as written in a base-specifier.
460 TemplateArgumentLoc getTemplateArgLocOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, bool ignore) {
461 return getTemplateArgLoc(Base->getTypeSourceInfo()->getTypeLoc(), ArgNum, ignore);
// Resolve template argument ArgNum of base Base to a class declaration, storing
// its location in Loc. Type arguments yield getAsCXXRecordDecl(), which is null
// for non-class types. A Null argument means a warning was already emitted.
// Other argument kinds are reported as unsupported.
464 CXXRecordDecl *getTemplateArgDeclOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, TemplateArgumentLoc &Loc, bool ignore = false) {
465 Loc = getTemplateArgLocOfBase(Base, ArgNum, ignore);
466 switch (Loc.getArgument().getKind()) {
467 case TemplateArgument::Type:
468 return Loc.getTypeSourceInfo()->getType()->getAsCXXRecordDecl();
469 case TemplateArgument::Null:
470 // Diag() was already called
473 Diag(Loc.getSourceRange().getBegin(), diag_warning) << Loc.getArgument().getKind() << "unsupported kind" << Loc.getSourceRange();
// Convenience overload for callers that do not need the argument's location.
478 CXXRecordDecl *getTemplateArgDeclOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, bool ignore = false) {
479 TemplateArgumentLoc Loc;
480 return getTemplateArgDeclOfBase(Base, ArgNum, Loc, ignore);
// Record a class derived from boost::statechart::simple_state (Base is that
// base-specifier) as a state. It reuses a forward-referenced State or creates
// one, then resolves the context (template argument 1) and the inner initial
// state (template argument 2), and processes the class's `reactions`
// typedef(s). It warns and marks the state red when there is no such typedef.
483 void handleSimpleState(CXXRecordDecl *RecordDecl, const CXXBaseSpecifier *Base)
486 string name(RecordDecl->getName()); //getQualifiedNameAsString());
487 Diag(RecordDecl->getLocStart(), diag_found_state) << name;
// Per-state bookkeeping of which react() methods are referenced.
488 reactMethodInReactions.clear();
491 // Either we saw a reference to forward declared state
492 // before, or we create a new state.
493 if (!(state = model.removeFromUndefinedContexts(name)))
494 state = new Model::State(name);
// The enclosing context; if it is not known yet, a placeholder state with
// that name is registered as undefined.
// NOTE(review): getTemplateArgDeclOfBase may return null; no null check is
// visible before Context->getName().
496 CXXRecordDecl *Context = getTemplateArgDeclOfBase(Base, 1);
498 Model::Context *c = model.findContext(Context->getName());
500 Model::State *s = new Model::State(Context->getName());
501 model.addUndefinedState(s);
506 // TODO: support more initial states
507 TemplateArgumentLoc Loc;
// ignore=true: this argument is looked up with warnings suppressed (presumably
// because it may be omitted).
508 if (MyCXXRecordDecl *InnerInitialState =
509 static_cast<MyCXXRecordDecl*>(getTemplateArgDeclOfBase(Base, 2, Loc, true))) {
510 if (InnerInitialState->isDerivedFrom("boost::statechart::simple_state") ||
511 InnerInitialState->isDerivedFrom("boost::statechart::state_machine")) {
512 state->setInitialInnerState(InnerInitialState->getName());
// NOTE(review): !compare(...) is true only when the names are EQUAL, and
// getNameAsString() normally yields an unqualified name, so this warning
// may never fire. Confirm the intended condition.
514 else if (!InnerInitialState->getNameAsString().compare("boost::mpl::list<>"))
515 Diag(Loc.getTypeSourceInfo()->getTypeLoc().getBeginLoc(), diag_warning)
516 << InnerInitialState->getName() << " as inner initial state is not supported" << Loc.getSourceRange();
519 // if (CXXRecordDecl *History = getTemplateArgDecl(Base->getType().getTypePtr(), 3))
520 // Diag(History->getLocStart(), diag_no_history);
522 IdentifierInfo& II = ASTCtx->Idents.get("reactions");
523 // TODO: Lookup for reactions even in base classes - probably by using Sema::LookupQualifiedName()
524 for (DeclContext::lookup_result Reactions = RecordDecl->lookup(DeclarationName(&II));
525 Reactions.first != Reactions.second; ++Reactions.first, typedef_num++)
526 HandleReaction(*Reactions.first, RecordDecl);
527 if(typedef_num == 0) {
528 Diag(RecordDecl->getLocStart(), diag_warning)
529 << RecordDecl->getName() << "state has no typedef for reactions";
530 state->setNoTypedef();
// Record a class derived from boost::statechart::state_machine as a Machine,
// taking its initial state from template argument 1 when that resolves to a class.
// Storing m into the model happens on a line not shown here.
534 void handleStateMachine(CXXRecordDecl *RecordDecl, const CXXBaseSpecifier *Base)
536 Model::Machine m(RecordDecl->getName());
537 Diag(RecordDecl->getLocStart(), diag_found_statemachine) << m.name;
539 if (MyCXXRecordDecl *InitialState =
540 static_cast<MyCXXRecordDecl*>(getTemplateArgDeclOfBase(Base, 1)))
541 m.setInitialState(InitialState->getName());
// RecursiveASTVisitor hook for every class declaration, including template
// instantiations (see shouldVisitTemplateInstantiations). Incomplete
// definitions and known base-class templates are skipped. The rest is
// dispatched as a state, a state machine, or an event definition.
545 bool VisitCXXRecordDecl(CXXRecordDecl *Declaration)
547 if (!Declaration->isCompleteDefinition())
549 if (Declaration->getQualifiedNameAsString() == "boost::statechart::state" ||
550 Declaration->getQualifiedNameAsString() == "TimedState" ||
551 Declaration->getQualifiedNameAsString() == "TimedSimpleState" ||
// Boost spells this class "asynchronous_state_machine". The previous
// "assynchronous" spelling never matched, so it was not skipped.
552 Declaration->getQualifiedNameAsString() == "boost::statechart::asynchronous_state_machine")
553 return true; // This is an "abstract class" not a real state or real state machine
555 MyCXXRecordDecl *RecordDecl = static_cast<MyCXXRecordDecl*>(Declaration);
556 const CXXBaseSpecifier *Base;
558 if (RecordDecl->isDerivedFrom("boost::statechart::simple_state", &Base))
559 handleSimpleState(RecordDecl, Base);
560 else if (RecordDecl->isDerivedFrom("boost::statechart::state_machine", &Base))
561 handleStateMachine(RecordDecl, Base);
562 else if (RecordDecl->isDerivedFrom("boost::statechart::event"))
563 listOfDefinedEvents.push_back(eventModel(RecordDecl->getNameAsString(), RecordDecl->getLocation()));
// Warn about every event class still in listOfDefinedEvents, i.e. one that no
// state's reactions referenced (HandleReaction removes referenced events).
566 void printUnusedEventDefinitions() {
567 for (list<eventModel>::const_iterator ev = listOfDefinedEvents.begin(), last = listOfDefinedEvents.end(); ev != last; ++ev)
568 Diag(ev->loc, diag_warning)
569 << ev->name << "event defined but not used in any state";
// ASTConsumer that runs the Visitor over a whole translation unit and writes
// the collected model as a DOT file to destFileName.
574 class VisualizeStatechartConsumer : public clang::ASTConsumer
// NOTE(review): visitor receives a reference to the `model` member. This is only
// safe if the Visitor constructor merely stores it, which it does, or if `model`
// is declared before `visitor`. The member declarations are not shown here.
580 explicit VisualizeStatechartConsumer(ASTContext *Context, std::string destFileName,
581 DiagnosticsEngine &D)
582 : visitor(Context, model, D), destFileName(destFileName) {}
// Order matters: initialize diagnostics, build the model, report unused events,
// then write the output file.
584 virtual void HandleTranslationUnit(clang::ASTContext &Context) {
585 visitor.initializeDiagnostic();
586 visitor.TraverseDecl(Context.getTranslationUnitDecl());
587 visitor.printUnusedEventDefinitions();
588 model.write_as_dot_file(destFileName);
// Clang plugin action: creates the consumer for each input file.
592 class VisualizeStatechartAction : public PluginASTAction
// The output name is the input path with everything from its last '.' removed.
// Line 598, not shown, presumably appends the extension.
// NOTE(review): with no '.' the whole path is kept, and a '.' in a directory
// name (e.g. "./dir.x/file") truncates inside the directory part.
595 ASTConsumer *CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) {
596 size_t dot = getCurrentFile().find_last_of('.');
597 std::string dest = getCurrentFile().substr(0, dot);
599 return new VisualizeStatechartConsumer(&CI.getASTContext(), dest, CI.getDiagnostics());
// Echo every plugin argument to stderr, report "-an-error" as an invalid
// argument (template example code), and print help when the first argument is "help".
602 bool ParseArgs(const CompilerInstance &CI,
603 const std::vector<std::string>& args) {
604 for (unsigned i = 0, e = args.size(); i != e; ++i) {
605 llvm::errs() << "Visualizer arg = " << args[i] << "\n";
607 // Example error handling.
608 if (args[i] == "-an-error") {
609 DiagnosticsEngine &D = CI.getDiagnostics();
610 unsigned DiagID = D.getCustomDiagID(
611 DiagnosticsEngine::Error, "invalid argument '" + args[i] + "'");
616 if (args.size() && args[0] == "help")
617 PrintHelp(llvm::errs());
// Placeholder help text.
621 void PrintHelp(llvm::raw_ostream& ros) {
622 ros << "Help for Visualize Statechart plugin goes here\n";
// Register the action with clang's plugin registry under the name "visualize-statechart".
627 static FrontendPluginRegistry::Add<VisualizeStatechartAction> X("visualize-statechart", "visualize statechart");