2 ////////////////////////////////////////////////////////////////////////////////////////
4 // This file is part of Boost Statechart Viewer.
6 // Boost Statechart Viewer is free software: you can redistribute it and/or modify
7 // it under the terms of the GNU General Public License as published by
8 // the Free Software Foundation, either version 3 of the License, or
9 // (at your option) any later version.
11 // Boost Statechart Viewer is distributed in the hope that it will be useful,
12 // but WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 // GNU General Public License for more details.
16 // You should have received a copy of the GNU General Public License
17 // along with Boost Statechart Viewer. If not, see <http://www.gnu.org/licenses/>.
19 ////////////////////////////////////////////////////////////////////////////////////////
21 //standard header files
28 #include "llvm/Support/raw_ostream.h"
29 #include "llvm/Support/raw_os_ostream.h"
32 #include "clang/AST/ASTConsumer.h"
33 #include "clang/AST/ASTContext.h"
34 #include "clang/AST/CXXInheritance.h"
35 #include "clang/AST/RecursiveASTVisitor.h"
36 #include "clang/Frontend/CompilerInstance.h"
37 #include "clang/Frontend/FrontendPluginRegistry.h"
39 using namespace clang;
// Indentation helpers used by the Graphviz (.dot) writers below.
// The current nesting level is stored per stream in an ios_base::iword()
// slot. The slot index is obtained once from ios_base::xalloc() and kept in
// a function-local static.
// NOTE(review): the remainder of getIndentLevelIdx() (presumably
// `return i;`) is not visible in this excerpt.
45 inline int getIndentLevelIdx() {
46 static int i = ios_base::xalloc();
// Manipulator: writes 2 spaces per nesting level of `os` (setw pads "").
50 ostream& indent(ostream& os) { os << setw(2*os.iword(getIndentLevelIdx())) << ""; return os; }
// Manipulator: increases the nesting level of `os` by one.
51 ostream& indent_inc(ostream& os) { os.iword(getIndentLevelIdx())++; return os; }
// Manipulator: decreases the nesting level of `os` by one.
52 ostream& indent_dec(ostream& os) { os.iword(getIndentLevelIdx())--; return os; }
// Context: a collection of states keyed by state name (a std::map from name
// to State*). Both State (for inner states) and Machine derive from it, which
// is what allows states to nest.
56 class Context : public map<string, State*> {
// Insert `state` keyed by its name (defined out of line).
58 iterator add(State *state);
// Find the Context for `name` among this context's states, recursing into
// nested states (defined out of line).
59 Context *findContext(const string &name);
// State: one statechart state. Because it is itself a Context, it can hold
// inner (nested) states.
62 class State : public Context
// Name of the inner initial state; operator<<(State) draws it with a double
// periphery and a dashed edge from this state.
64 string initialInnerState;
// Events deferred by this state, collected from `deferral<>` reactions and
// from `defer_event` calls inside custom react() methods; printed as
// "EV / defer". (The identifier's spelling is kept as-is.)
65 list<string> defferedEvents;
// Events from `in_state_reaction<>` reactions; printed as "EV / in state".
66 list<string> inStateEvents;
70 explicit State(string name) : noTypedef(false), name(name) {}
71 void setInitialInnerState(string name) { initialInnerState = name; }
72 void addDeferredEvent(const string &name) { defferedEvents.push_back(name); }
73 void addInStateEvent(const string &name) { inStateEvents.push_back(name); }
// Mark the state as having no `reactions` declaration; operator<<(State)
// then draws the node in red.
74 void setNoTypedef() { noTypedef = true;}
75 friend ostream& operator<<(ostream& os, const State& s);
// Insert `state` under its name.
// std::map::insert never overwrites an existing entry with the same key:
// ret.second tells whether insertion happened, and ret.first points at the
// new or already-existing element.
// NOTE(review): what is done with `ret` afterwards is not visible in this
// excerpt.
79 Context::iterator Context::add(State *state)
81 pair<iterator, bool> ret = insert(value_type(state->name, state));
// Find the Context of the state called `name`.
// First looks among the direct children with find(), then recurses into
// every child state in map (name) order.
// NOTE(review): the handling of a direct hit and the result on a miss are
// not visible in this excerpt (presumably nullptr when nothing matches).
85 Context *Context::findContext(const string &name)
87 iterator i = find(name), e;
90 for (i = begin(), e = end(); i != e; ++i) {
91 Context *c = i->second->findContext(name);
// Forward declaration: operator<<(State) prints its inner states through it.
98 ostream& operator<<(ostream& os, const Context& c);
// Write one state as a Graphviz node with an HTML-like label.
// The label is the state name followed by one "<br />"-separated line per
// deferred event ("EV / defer") and per in-state event ("EV / in state").
// States flagged by setNoTypedef() are colored red.
// The remaining output is guarded by a condition that is not visible in
// this excerpt. It consists of a dashed edge to the inner initial state and
// a "cluster_<name>" subgraph holding the inner states, with the initial one
// drawn with a double periphery.
100 ostream& operator<<(ostream& os, const State& s)
102 string label = s.name;
103 for (list<string>::const_iterator i = s.defferedEvents.begin(), e = s.defferedEvents.end(); i != e; ++i)
104 label.append("<br />").append(*i).append(" / defer");
105 for (list<string>::const_iterator i = s.inStateEvents.begin(), e = s.inStateEvents.end(); i != e; ++i)
106 label.append("<br />").append(*i).append(" / in state");
107 if (s.noTypedef) os << indent << s.name << " [label=<" << label << ">, color=\"red\"]\n";
108 else os << indent << s.name << " [label=<" << label << ">]\n";
110 os << indent << s.name << " -> " << s.initialInnerState << " [style = dashed]\n";
111 os << indent << "subgraph cluster_" << s.name << " {\n" << indent_inc;
112 os << indent << "label = \"" << s.name << "\"\n";
113 os << indent << s.initialInnerState << " [peripheries=2]\n";
// NOTE(review): static_cast<Context>(s) builds a temporary copy of the whole
// child map just to select the Context overload;
// static_cast<const Context&>(s) would avoid the copy.
114 os << static_cast<Context>(s);
115 os << indent_dec << indent << "}\n";
// Write every state of context `c`, in map (i.e. state-name) order.
// NOTE(review): the loop body is not visible in this excerpt; presumably it
// streams each *i->second through operator<<(State).
121 ostream& operator<<(ostream& os, const Context& c)
123 for (Context::const_iterator i = c.begin(), e = c.end(); i != e; i++) {
// Transition edge from state `src` to state `dst`, labelled with the name of
// the triggering event `event`.
133 const string src, dst, event;
134 Transition(string src, string dst, string event) : src(src), dst(dst), event(event) {}
// Write a transition as the Graphviz edge `src -> dst [label = "event"]`.
137 ostream& operator<<(ostream& os, const Transition& t)
139 os << indent << t.src << " -> " << t.dst << " [label = \"" << t.event << "\"]\n";
// Machine: one state machine, built for each class derived from
// boost::statechart::state_machine. As a Context it holds states.
144 class Machine : public Context
// Name of the machine's initial state; drawn with a double periphery.
147 string initial_state;
150 explicit Machine(string name) : name(name) {}
152 void setInitialState(string name) { initial_state = name; }
154 friend ostream& operator<<(ostream& os, const Machine& m);
// Write a machine as a Graphviz subgraph named after the machine: the
// initial state gets a double periphery, then all states are written.
// NOTE(review): unlike states, the subgraph name has no "cluster_" prefix,
// so dot does not draw a box around the machine; confirm this is intended.
// static_cast<Context>(m) copies the whole map;
// static_cast<const Context&>(m) would avoid the copy.
157 ostream& operator<<(ostream& os, const Machine& m)
159 os << indent << "subgraph " << m.name << " {\n" << indent_inc;
160 os << indent << m.initial_state << " [peripheries=2]\n";
161 os << static_cast<Context>(m);
162 os << indent_dec << indent << "}\n";
// Model: the whole diagram. It holds:
//  - all state machines, keyed by name;
//  - placeholder states for contexts referenced before their definition;
//  - all transitions.
// write_as_dot_file() serializes it.
167 class Model : public map<string, Machine>
169 Context undefined; // For forward-declared state classes
// Transitions collected from transition<> reactions and transit<>() calls.
// NOTE(review): the visitors allocate these with `new`, and no `delete` is
// visible in this excerpt. list<Transition> or unique_ptr would make the
// ownership explicit.
171 list< Transition*> transitions;
// Add a copy of machine `m` under its name; an existing machine with the
// same name is not replaced (std::map::insert semantics).
173 iterator add(const Machine &m)
175 pair<iterator, bool> ret = insert(value_type(m.name, m));
// Record placeholder state `m` for a context seen before its definition.
// operator[] overwrites any previous placeholder with the same name.
179 void addUndefinedState(State *m)
181 undefined[m->name] = m;
// Find the Context named `name`. Search order: placeholders in `undefined`,
// then machines by name, then states nested anywhere inside each machine.
// NOTE(review): the return statements are not visible in this excerpt.
185 Context *findContext(const string &name)
187 Context::iterator ci = undefined.find(name);
188 if (ci != undefined.end())
190 iterator i = find(name), e;
193 for (i = begin(), e = end(); i != e; ++i) {
194 Context *c = i->second.findContext(name);
// Find the state called `name` inside any machine; the Context* found is
// static_cast to State*.
// NOTE(review): placeholders in `undefined`, and states nested inside them,
// are not searched. A state whose parent has not been defined yet is
// therefore not found here. TODO confirm this is intended.
201 State *findState(const string &name)
203 for (iterator i = begin(), e = end(); i != e; ++i) {
204 Context *c = i->second.findContext(name);
206 return static_cast<State*>(c);
// Take the placeholder for `name` out of `undefined` so that the real
// definition can reuse it. What is returned is not visible in this excerpt
// (presumably the placeholder, or nullptr when none exists).
212 State *removeFromUndefinedContexts(const string &name)
214 Context::iterator ci = undefined.find(name);
215 if (ci == undefined.end())
// Write the model to file `fn` as the Graphviz digraph "statecharts": one
// loop over the machines, then one over the transitions (loop bodies not
// visible in this excerpt).
// NOTE(review): there is no visible check that the file was opened
// successfully; a failed open silently produces no output.
221 void write_as_dot_file(string fn)
223 ofstream f(fn.c_str());
224 f << "digraph statecharts {\n" << indent_inc;
225 for (iterator i = begin(), e = end(); i != e; i++)
227 for (list<Transition*>::iterator t = transitions.begin(), e = transitions.end(); t != e; ++t)
229 f << indent_dec << "}\n";
// MyCXXRecordDecl: adds a string-based isDerivedFrom() to clang's
// CXXRecordDecl. No data members are visible; callers obtain one by
// static_cast-ing a CXXRecordDecl*.
// NOTE(review): that downcast is formally undefined behavior, because the
// object is not really a MyCXXRecordDecl. A free function taking
// `const CXXRecordDecl *` would avoid it.
235 class MyCXXRecordDecl : public CXXRecordDecl
// Base-class search callback. It compares the canonical, fully qualified
// name of the base with the C string passed through `qualName`.
// NOTE(review): isDerivedFrom() below uses an equivalent lambda, so this
// helper appears unused in this excerpt.
237 static bool FindBaseClassString(const CXXBaseSpecifier *Specifier,
241 string qn(static_cast<const char*>(qualName));
242 const RecordType *rt = Specifier->getType()->getAs<RecordType>();
244 TagDecl *canon = rt->getDecl()->getCanonicalDecl();
245 return canon->getQualifiedNameAsString() == qn;
// Test whether this class has a base whose canonical, fully qualified name
// equals `baseStr` (e.g. "boost::statechart::simple_state"), using a
// CXXBasePaths search rooted at this class.
// When `Base` is non-null, paths are recorded (RecordPaths=!!Base). On a
// match, *Base is set to the Base of the last element of the first path:
// the base specifier that names the matching class, whose template
// arguments the callers inspect. (The search call and returns are not
// visible in this excerpt.)
249 bool isDerivedFrom(const char *baseStr, CXXBaseSpecifier const **Base = 0) const {
250 CXXBasePaths Paths(/*FindAmbiguities=*/false, /*RecordPaths=*/!!Base, /*DetectVirtual=*/false);
251 Paths.setOrigin(const_cast<MyCXXRecordDecl*>(this));
254 [qn](const CXXBaseSpecifier *Specifier, CXXBasePath &Path) -> bool {
255 const RecordType *rt = Specifier->getType()->getAs<RecordType>();
257 TagDecl *canon = rt->getDecl()->getCanonicalDecl();
258 return canon->getQualifiedNameAsString() == qn;
265 *Base = Paths.front().back().Base;
// FindTransitVisitor: walks the body of the custom react() method of state
// SrcState that handles event EventType, and records into `model`:
//  - member expressions named `defer_event` as deferred events of SrcState;
//  - member expressions named `transit` with explicit template arguments as
//    transitions SrcState -> Dst, labelled with the event's name, where Dst
//    is the first template argument.
// For other member names, the statement after the `!= "transit"` test runs;
// it is not visible in this excerpt.
270 class FindTransitVisitor : public RecursiveASTVisitor<FindTransitVisitor>
273 const CXXRecordDecl *SrcState;
274 const Type *EventType;
276 explicit FindTransitVisitor(Model::Model &model, const CXXRecordDecl *SrcState, const Type *EventType)
277 : model(model), SrcState(SrcState), EventType(EventType) {}
279 bool VisitMemberExpr(MemberExpr *E) {
280 if (E->getMemberNameInfo().getAsString() == "defer_event") {
281 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
283 Model::State *s = model.findState(SrcState->getName());
285 s->addDeferredEvent(Event->getName());
286 } else if (E->getMemberNameInfo().getAsString() != "transit")
// NOTE(review): getAsType() is used without checking that the argument is a
// type, and getAsCXXRecordDecl() returns null for non-class (e.g.
// dependent) types. DstState and Event are dereferenced below without a
// null check.
288 if (E->hasExplicitTemplateArgs()) {
289 const Type *DstStateType = E->getTemplateArgs()[0].getArgument().getAsType().getTypePtr();
290 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
291 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
292 Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
293 model.transitions.push_back(T);
// Visitor: walks the AST, fills `model` and reports custom diagnostics
// through Diags. It recognizes:
//  - state machines (classes derived from boost::statechart::state_machine);
//  - states (derived from boost::statechart::simple_state) and their
//    reactions;
//  - events (derived from boost::statechart::event).
299 class Visitor : public RecursiveASTVisitor<Visitor>
// eventModel: an event class name plus the location of its definition; kept
// in unusedEvents until some reaction refers to the event.
304 eventModel(string ev, SourceLocation sourceLoc) : name(ev), loc(sourceLoc){}
// Predicate for list::remove_if(): true for the eventModel named eventName.
307 struct eventHasName {
309 eventHasName(string name) : eventName(name){}
310 bool operator() (const eventModel& model) { return (eventName.compare(model.name) == 0); }
314 DiagnosticsEngine &Diags;
// Custom diagnostic IDs, registered in the constructor.
// NOTE(review): diag_no_history is registered, but its only use (in
// handleSimpleState) is commented out.
315 unsigned diag_unhandled_reaction_type, diag_unhandled_reaction_decl,
316 diag_found_state, diag_found_statemachine, diag_no_history, diag_missing_reaction, diag_warning;
317 std::vector<bool> reactMethodInReactions; // Indicates whether i-th react method is referenced from typedef reactions.
// Events whose definitions have been visited but that no reaction has
// referenced yet.
// NOTE(review): an event is removed when a reaction mentions it, but only
// added when its own definition is visited. If the definition is visited
// after the state that references it (e.g. a forward-declared event), the
// event stays here and is reported as unused.
318 std::list<eventModel> unusedEvents;
// Ask RecursiveASTVisitor to traverse template instantiations as well.
321 bool shouldVisitTemplateInstantiations() const { return true; }
// Keep the AST context, the model to fill and the diagnostics engine, and
// register the custom diagnostic IDs used below.
323 explicit Visitor(ASTContext *Context, Model::Model &model, DiagnosticsEngine &Diags)
324 : ASTCtx(Context), model(model), Diags(Diags)
326 diag_found_statemachine =
327 Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found statemachine '%0'");
329 Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found state '%0'");
330 diag_unhandled_reaction_type =
331 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction type '%0'");
332 diag_unhandled_reaction_decl =
333 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction decl '%0'");
335 Diags.getCustomDiagID(DiagnosticsEngine::Error, "History is not yet supported");
336 diag_missing_reaction =
337 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Missing react method for event '%0'");
339 Diags.getCustomDiagID(DiagnosticsEngine::Warning, "'%0' %1");
// Shorthand for Diags.Report(Loc, DiagID).
342 DiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID) { return Diags.Report(Loc, DiagID); }
// Warn about every `react` member of SrcState whose slot in
// reactMethodInReactions is missing or false. Slots are set by
// HandleCustomReaction(), so these are react methods that no
// custom_reaction<> in `reactions` refers to.
// NOTE(review): dyn_cast<CXXMethodDecl> yields null for non-method lookup
// results (e.g. a member function template), and getParamDecl(0) assumes at
// least one parameter. HandleCustomReaction() checks both; this function
// does not.
344 void checkAllReactMethods(const CXXRecordDecl *SrcState)
347 IdentifierInfo& II = ASTCtx->Idents.get("react");
348 auto ReactRes = SrcState->lookup(DeclarationName(&II));
349 for (auto it = ReactRes.begin(), end=ReactRes.end(); it != end; ++it, ++i) {
350 if (i >= reactMethodInReactions.size() || reactMethodInReactions[i] == false) {
351 CXXMethodDecl *React = dyn_cast<CXXMethodDecl>(*it);
352 Diag(React->getParamDecl(0)->getLocStart(), diag_warning)
353 << React->getParamDecl(0)->getType().getAsString() << " missing in typedef reactions";
// Handle custom_reaction<EventType> for SrcState.
// Looks for a `react` member whose first parameter type, after stripping an
// lvalue reference and any qualifiers, is EventType. The matching overload's
// body is walked with FindTransitVisitor and its reactMethodInReactions slot
// is set. Non-method members and methods without parameters are warned
// about. The returned value is computed in lines not visible in this
// excerpt; the caller treats false as "no matching react method".
// NOTE(review): `ParmType == EventType` compares Type pointers, so type
// sugar (typedefs, elaborated types) can make equal types compare unequal.
// Comparing canonical types (e.g. ASTContext::hasSameUnqualifiedType) would
// be more robust. Likewise, dyn_cast<LValueReferenceType> returns null when
// the reference type is hidden behind a typedef.
358 bool HandleCustomReaction(const CXXRecordDecl *SrcState, const Type *EventType)
361 IdentifierInfo& II = ASTCtx->Idents.get("react");
362 // TODO: Lookup for react even in base classes - probably by using Sema::LookupQualifiedName()
363 auto ReactRes = SrcState->lookup(DeclarationName(&II));
364 for (auto it = ReactRes.begin(), end=ReactRes.end(); it != end; ++it) {
365 if (CXXMethodDecl *React = dyn_cast<CXXMethodDecl>(*it)) {
366 if (React->getNumParams() >= 1) {
367 const ParmVarDecl *p = React->getParamDecl(0);
368 const Type *ParmType = p->getType().getTypePtr();
369 if (i == reactMethodInReactions.size()) reactMethodInReactions.push_back(false);
370 if (ParmType->isLValueReferenceType())
371 ParmType = dyn_cast<LValueReferenceType>(ParmType)->getPointeeType().getTypePtr();
372 if (ParmType == EventType) {
373 FindTransitVisitor(model, SrcState, EventType).TraverseStmt(React->getBody());
374 reactMethodInReactions[i] = true;
378 Diag(React->getLocStart(), diag_warning)
379 << React << "has not a parameter";
381 Diag((*it)->getSourceRange().getBegin(), diag_warning)
382 << (*it)->getDeclKindName() << "is not supported as react method";
// Translate one reaction type of SrcState into the model. Elaborated types
// are unwrapped first; template specializations are dispatched by qualified
// template name:
//  - boost::statechart::transition<Ev, Dst> -> Transition SrcState -> Dst
//    labelled Ev;
//  - boost::statechart::custom_reaction<Ev> -> HandleCustomReaction(), with
//    an error when it returns false;
//  - boost::statechart::deferral<Ev> -> deferred event of SrcState;
//  - boost::statechart::in_state_reaction<Ev, ...> -> in-state event of
//    SrcState;
//  - boost::mpl::list<...> -> each element handled recursively.
// Each event named by these reactions is removed from unusedEvents. Any
// other template name, or any other kind of type, is reported at Loc as
// diag_unhandled_reaction_type.
// NOTE(review): getAsCXXRecordDecl() returns null for non-class (e.g.
// dependent) types, and several branches dereference its result directly.
// The local `Model::Transition *T` in the transition branch shadows the
// parameter `T`.
388 void HandleReaction(const Type *T, const SourceLocation Loc, CXXRecordDecl *SrcState)
389 // TODO: Improve Loc tracking
391 if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(T))
392 HandleReaction(ET->getNamedType().getTypePtr(), Loc, SrcState);
393 else if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(T)) {
394 string name = TST->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
395 if (name == "boost::statechart::transition") {
396 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
397 const Type *DstStateType = TST->getArg(1).getAsType().getTypePtr();
398 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
399 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
400 unusedEvents.remove_if(eventHasName(Event->getNameAsString()));
402 Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
403 model.transitions.push_back(T);
404 } else if (name == "boost::statechart::custom_reaction") {
405 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
406 if (!HandleCustomReaction(SrcState, EventType)) {
407 Diag(SrcState->getLocation(), diag_missing_reaction) << EventType->getAsCXXRecordDecl()->getName();
409 unusedEvents.remove_if(eventHasName(EventType->getAsCXXRecordDecl()->getNameAsString()));
410 } else if (name == "boost::statechart::deferral") {
411 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
412 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
413 unusedEvents.remove_if(eventHasName(Event->getNameAsString()));
415 Model::State *s = model.findState(SrcState->getName());
417 s->addDeferredEvent(Event->getName());
418 } else if (name == "boost::mpl::list") {
419 for (TemplateSpecializationType::iterator Arg = TST->begin(), End = TST->end(); Arg != End; ++Arg)
420 HandleReaction(Arg->getAsType().getTypePtr(), Loc, SrcState);
421 } else if (name == "boost::statechart::in_state_reaction") {
422 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
423 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
424 unusedEvents.remove_if(eventHasName(Event->getNameAsString()));
426 Model::State *s = model.findState(SrcState->getName());
428 s->addInStateEvent(Event->getName());
431 Diag(Loc, diag_unhandled_reaction_type) << name;
433 Diag(Loc, diag_unhandled_reaction_type) << T->getTypeClassName();
// Handle one declaration named `reactions` in SrcState.
// A typedef's underlying type is passed to HandleReaction(Type); any other
// declaration kind is reported as diag_unhandled_reaction_decl. Then
// checkAllReactMethods(SrcState) is called.
// NOTE(review): only TypedefDecl is matched, so a C++11 alias
// `using reactions = ...;` (a TypeAliasDecl) is reported as unhandled.
// Matching TypedefNameDecl would accept both forms.
436 void HandleReaction(const NamedDecl *Decl, CXXRecordDecl *SrcState)
438 if (const TypedefDecl *r = dyn_cast<TypedefDecl>(Decl))
439 HandleReaction(r->getCanonicalDecl()->getUnderlyingType().getTypePtr(),
440 r->getLocStart(), SrcState);
442 Diag(Decl->getLocation(), diag_unhandled_reaction_decl) << Decl->getDeclKindName();
443 checkAllReactMethods(SrcState);
// Return the location of template argument ArgNum of the (possibly
// elaborated) template specialization written as T. When T is not a
// template specialization, or has too few arguments, a warning may be
// emitted and an empty TemplateArgumentLoc is returned.
// NOTE(review): how `ignore` is used is not visible in this excerpt
// (presumably it suppresses the "not enough arguments" warning for optional
// arguments).
446 TemplateArgumentLoc getTemplateArgLoc(const TypeLoc &T, unsigned ArgNum, bool ignore)
448 if (const ElaboratedTypeLoc ET = T.getAs<ElaboratedTypeLoc>())
449 return getTemplateArgLoc(ET.getNamedTypeLoc(), ArgNum, ignore);
450 else if (const TemplateSpecializationTypeLoc TST = T.getAs<TemplateSpecializationTypeLoc>()) {
451 if (TST.getNumArgs() >= ArgNum+1) {
452 return TST.getArgLoc(ArgNum);
455 Diag(TST.getBeginLoc(), diag_warning) << TST.getType()->getTypeClassName() << "has not enough arguments" << TST.getSourceRange();
457 Diag(T.getBeginLoc(), diag_warning) << T.getType()->getTypeClassName() << "type as template argument is not supported" << T.getSourceRange();
458 return TemplateArgumentLoc();
// getTemplateArgLoc() applied to the type as written in base specifier Base.
461 TemplateArgumentLoc getTemplateArgLocOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, bool ignore) {
462 return getTemplateArgLoc(Base->getTypeSourceInfo()->getTypeLoc(), ArgNum, ignore);
// Return the class declaration given as template argument ArgNum of base
// specifier Base, and store the argument's location in Loc.
// Type arguments return their CXXRecordDecl (null for non-class types). A
// Null argument was already diagnosed; other argument kinds get a warning.
// The results for these two cases are not visible in this excerpt
// (presumably null).
465 CXXRecordDecl *getTemplateArgDeclOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, TemplateArgumentLoc &Loc, bool ignore = false) {
466 Loc = getTemplateArgLocOfBase(Base, ArgNum, ignore);
467 switch (Loc.getArgument().getKind()) {
468 case TemplateArgument::Type:
469 return Loc.getTypeSourceInfo()->getType()->getAsCXXRecordDecl();
470 case TemplateArgument::Null:
471 // Diag() was already called
474 Diag(Loc.getSourceRange().getBegin(), diag_warning) << Loc.getArgument().getKind() << "unsupported kind" << Loc.getSourceRange();
// Convenience overload that discards the argument location.
479 CXXRecordDecl *getTemplateArgDeclOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, bool ignore = false) {
480 TemplateArgumentLoc Loc;
481 return getTemplateArgDeclOfBase(Base, ArgNum, Loc, ignore);
// Add state RecordDecl (Base is its simple_state<> base specifier) to the
// model:
//  - reuse a placeholder left by an earlier forward reference, or create a
//    new State;
//  - resolve its context from template argument 1, creating a placeholder
//    when that context is not known yet;
//  - record its inner initial state from template argument 2;
//  - process every `reactions` declaration, flagging the state when there
//    is none.
484 void handleSimpleState(CXXRecordDecl *RecordDecl, const CXXBaseSpecifier *Base)
487 string name(RecordDecl->getName()); //getQualifiedNameAsString());
488 Diag(RecordDecl->getLocStart(), diag_found_state) << name;
489 reactMethodInReactions.clear();
492 // Either we saw a reference to forward declared state
493 // before, or we create a new state.
494 if (!(state = model.removeFromUndefinedContexts(name)))
495 state = new Model::State(name);
497 CXXRecordDecl *Context = getTemplateArgDeclOfBase(Base, 1);
499 Model::Context *c = model.findContext(Context->getName());
501 Model::State *s = new Model::State(Context->getName());
502 model.addUndefinedState(s);
507 //TODO support more initial states
508 TemplateArgumentLoc Loc;
509 if (MyCXXRecordDecl *InnerInitialState =
510 static_cast<MyCXXRecordDecl*>(getTemplateArgDeclOfBase(Base, 2, Loc, true))) {
511 if (InnerInitialState->isDerivedFrom("boost::statechart::simple_state") ||
512 InnerInitialState->isDerivedFrom("boost::statechart::state_machine")) {
513 state->setInitialInnerState(InnerInitialState->getName());
// NOTE(review): this warning can never fire.
//  - getNameAsString() returns the unqualified name without template
//    arguments (e.g. "list"), so it never equals "boost::mpl::list<>".
//  - `!compare(...)` is true only on equality, so the test is also inverted.
// The intent was probably to warn about any inner initial state that is not
// the default empty boost::mpl::list<>.
515 else if (!InnerInitialState->getNameAsString().compare("boost::mpl::list<>"))
516 Diag(Loc.getTypeSourceInfo()->getTypeLoc().getBeginLoc(), diag_warning)
517 << InnerInitialState->getName() << " as inner initial state is not supported" << Loc.getSourceRange();
// History (template argument 3) is not handled yet:
520 // if (CXXRecordDecl *History = getTemplateArgDecl(Base->getType().getTypePtr(), 3))
521 // Diag(History->getLocStart(), diag_no_history);
523 IdentifierInfo& II = ASTCtx->Idents.get("reactions");
524 // TODO: Lookup for reactions even in base classes - probably by using Sema::LookupQualifiedName()
525 auto Reactions = RecordDecl->lookup(DeclarationName(&II));
526 for (auto it = Reactions.begin(), end = Reactions.end(); it != end; ++it, typedef_num++)
527 HandleReaction(*it, RecordDecl);
528 if(typedef_num == 0) {
529 Diag(RecordDecl->getLocStart(), diag_warning)
530 << RecordDecl->getName() << "state has no typedef for reactions";
531 state->setNoTypedef();
// Build a Model::Machine for state machine RecordDecl (Base is its
// state_machine<> base specifier); template argument 1 supplies the initial
// state. Adding `m` to the model happens in lines not visible in this
// excerpt.
535 void handleStateMachine(CXXRecordDecl *RecordDecl, const CXXBaseSpecifier *Base)
537 Model::Machine m(RecordDecl->getName());
538 Diag(RecordDecl->getLocStart(), diag_found_statemachine) << m.name;
540 if (MyCXXRecordDecl *InitialState =
541 static_cast<MyCXXRecordDecl*>(getTemplateArgDeclOfBase(Base, 1)))
542 m.setInitialState(InitialState->getName());
// RecursiveASTVisitor hook called for every C++ class. Only complete
// definitions are processed; a few listed base classes are skipped as
// "abstract". The rest are classified by their bases:
//  - simple_state<> derivatives become states;
//  - state_machine<> derivatives become machines;
//  - event<> derivatives are recorded as (so far) unused events.
// NOTE(review): the Boost class is spelled `asynchronous_state_machine`, so
// the literal "boost::statechart::assynchronous_state_machine" (double "s")
// never matches and that skip never applies.
// The static_cast to MyCXXRecordDecl is a downcast to a type the object
// does not actually have (formally undefined behavior).
546 bool VisitCXXRecordDecl(CXXRecordDecl *Declaration)
548 if (!Declaration->isCompleteDefinition())
550 if (Declaration->getQualifiedNameAsString() == "boost::statechart::state" ||
551 Declaration->getQualifiedNameAsString() == "TimedState" ||
552 Declaration->getQualifiedNameAsString() == "TimedSimpleState" ||
553 Declaration->getQualifiedNameAsString() == "boost::statechart::assynchronous_state_machine")
554 return true; // This is an "abstract class" not a real state or real state machine
556 MyCXXRecordDecl *RecordDecl = static_cast<MyCXXRecordDecl*>(Declaration);
557 const CXXBaseSpecifier *Base;
559 if (RecordDecl->isDerivedFrom("boost::statechart::simple_state", &Base))
560 handleSimpleState(RecordDecl, Base);
561 else if (RecordDecl->isDerivedFrom("boost::statechart::state_machine", &Base))
562 handleStateMachine(RecordDecl, Base);
563 else if (RecordDecl->isDerivedFrom("boost::statechart::event")) {
564 // Mark the event as unused until we find that somebody uses it
565 unusedEvents.push_back(eventModel(RecordDecl->getNameAsString(), RecordDecl->getLocation()));
// Emit a warning at the definition of every event still in unusedEvents.
569 void printUnusedEventDefinitions() {
570 for(list<eventModel>::iterator it = unusedEvents.begin(); it!=unusedEvents.end(); it++)
571 Diag((*it).loc, diag_warning)
572 << (*it).name << "event defined but not used in any state";
// ASTConsumer that runs Visitor over the whole translation unit, reports
// unused events, and writes the collected model to destFileName.
577 class VisualizeStatechartConsumer : public clang::ASTConsumer
583 explicit VisualizeStatechartConsumer(ASTContext *Context, std::string destFileName,
584 DiagnosticsEngine &D)
585 : visitor(Context, model, D), destFileName(destFileName) {}
// Called once the translation unit has been parsed: traverse it, warn about
// unused events, then write the .dot output.
587 virtual void HandleTranslationUnit(clang::ASTContext &Context) {
588 visitor.TraverseDecl(Context.getTranslationUnitDecl());
589 visitor.printUnusedEventDefinitions();
590 model.write_as_dot_file(destFileName);
// Clang plugin action: creates a VisualizeStatechartConsumer whose output
// file name is derived from the current input file.
594 class VisualizeStatechartAction : public PluginASTAction
// The output path is the input path with everything from its last '.'
// removed; any suffix added afterwards is in lines not visible in this
// excerpt.
// NOTE(review): find_last_of('.') also matches a dot in a directory name
// when the file itself has no extension (e.g. "dir.v2/main" becomes "dir").
// Searching only after the last path separator would avoid this.
597 std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) {
598 size_t dot = getCurrentFile().find_last_of('.');
599 std::string dest = getCurrentFile().substr(0, dot);
601 return std::unique_ptr<ASTConsumer>( new VisualizeStatechartConsumer(&CI.getASTContext(), dest, CI.getDiagnostics()) );
// Plugin argument handling:
//  - every argument is echoed to stderr;
//  - "-an-error" registers a custom error diagnostic;
//  - "help" as the first argument prints the help text.
// NOTE(review): the "-an-error" branch looks like leftover sample code
// (compare Clang's PrintFunctionNames example plugin).
604 bool ParseArgs(const CompilerInstance &CI,
605 const std::vector<std::string>& args) {
606 for (unsigned i = 0, e = args.size(); i != e; ++i) {
607 llvm::errs() << "Visualizer arg = " << args[i] << "\n";
609 // Example error handling.
610 if (args[i] == "-an-error") {
611 DiagnosticsEngine &D = CI.getDiagnostics();
612 unsigned DiagID = D.getCustomDiagID(
613 DiagnosticsEngine::Error, "invalid argument '%0' expected '%1'");
618 if (args.size() && args[0] == "help")
619 PrintHelp(llvm::errs());
// Print the plugin's (placeholder) help text to `ros`.
623 void PrintHelp(llvm::raw_ostream& ros) {
624 ros << "Help for Visualize Statechart plugin goes here\n";
// Register the plugin with Clang under the name "visualize-statechart".
629 static FrontendPluginRegistry::Add<VisualizeStatechartAction> X("visualize-statechart", "visualize statechart");