2 ////////////////////////////////////////////////////////////////////////////////////////
4 // This file is part of Boost Statechart Viewer.
6 // Boost Statechart Viewer is free software: you can redistribute it and/or modify
7 // it under the terms of the GNU General Public License as published by
8 // the Free Software Foundation, either version 3 of the License, or
9 // (at your option) any later version.
11 // Boost Statechart Viewer is distributed in the hope that it will be useful,
12 // but WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 // GNU General Public License for more details.
16 // You should have received a copy of the GNU General Public License
17 // along with Boost Statechart Viewer. If not, see <http://www.gnu.org/licenses/>.
19 ////////////////////////////////////////////////////////////////////////////////////////
21 //standard header files
28 #include "llvm/Support/raw_ostream.h"
29 #include "llvm/Support/raw_os_ostream.h"
32 #include "clang/AST/ASTConsumer.h"
33 #include "clang/AST/ASTContext.h"
34 #include "clang/AST/CXXInheritance.h"
35 #include "clang/AST/RecursiveASTVisitor.h"
36 #include "clang/Frontend/CompilerInstance.h"
37 #include "clang/Frontend/FrontendPluginRegistry.h"
39 using namespace clang;
// Stream manipulators that implement nested indentation for the DOT output.
// The current indent level is kept per stream in an ios_base::iword slot whose
// index is allocated once via ios_base::xalloc() into a function-local static.
// (The rest of getIndentLevelIdx is not visible in this listing; presumably it
// returns i — TODO confirm.)
45 inline int getIndentLevelIdx() {
46 static int i = ios_base::xalloc();
// indent: pads with 2 * level fill characters (a space unless the stream's fill
// was changed) by writing an empty string with setw.
50 ostream& indent(ostream& os) { os << setw(2*os.iword(getIndentLevelIdx())) << ""; return os; }
// indent_inc / indent_dec: raise or lower this stream's level; they write nothing.
// NOTE(review): nothing here stops the level from going negative on unbalanced use.
51 ostream& indent_inc(ostream& os) { os.iword(getIndentLevelIdx())++; return os; }
52 ostream& indent_dec(ostream& os) { os.iword(getIndentLevelIdx())--; return os; }
// Context: a collection of states keyed by state name. Both State (for nested
// inner states) and Machine (for top-level states) derive from it.
// The State* values are raw pointers; nothing visible in this listing deletes them.
56 class Context : public map<string, State*> {
58 iterator add(State *state);
59 Context *findContext(const string &name);
// State: one statechart state. Its inner states live in the inherited Context map.
62 class State : public Context
// Name of the inner initial state. operator<<(State) draws it with a double
// border and a dashed edge from this state.
64 string initialInnerState;
// Events this state defers, shown as "<event> / defer" in the node label.
// (Identifier is misspelled "deffered"; kept as-is because other code refers to it.)
65 list<string> defferedEvents;
69 explicit State(string name) : noTypedef(false), name(name) {}
70 void setInitialInnerState(string name) { initialInnerState = name; }
71 void addDeferredEvent(const string &name) { defferedEvents.push_back(name); }
// Marks a state that has no `reactions` typedef; such states are drawn in red.
72 void setNoTypedef() { noTypedef = true;}
73 friend ostream& operator<<(ostream& os, const State& s);
// Adds `state` under its own name. map::insert does not overwrite, so an
// existing entry with the same name is kept; ret.second tells whether the
// insertion happened. (The rest of the body is not visible in this listing.)
77 Context::iterator Context::add(State *state)
79 pair<iterator, bool> ret = insert(value_type(state->name, state));
// Looks up the state named `name` in this context. It tries this map's own
// entries first, then recurses into each child state's findContext().
// The lines handling the direct hit and the "not found" result are not
// visible here; presumably it returns the match or a null pointer — TODO confirm.
83 Context *Context::findContext(const string &name)
85 iterator i = find(name), e;
88 for (i = begin(), e = end(); i != e; ++i) {
89 Context *c = i->second->findContext(name);
// Forward declaration so that operator<<(State) can print a state's inner states.
96 ostream& operator<<(ostream& os, const Context& c);
// Emits one state as a Graphviz node with an HTML-like label (label=<...>).
// The label is the state name followed by one "<br />" + "<event> / defer" line
// per deferred event. States flagged by setNoTypedef() are drawn red.
// It then emits a dashed edge to the inner initial state and a cluster subgraph
// that contains the inner states. In the full file this part is presumably
// guarded by a check that the state has inner states — TODO confirm.
98 ostream& operator<<(ostream& os, const State& s)
100 string label = s.name;
101 for (list<string>::const_iterator i = s.defferedEvents.begin(), e = s.defferedEvents.end(); i != e; ++i)
102 label.append("<br />").append(*i).append(" / defer");
103 if (s.noTypedef) os << indent << s.name << " [label=<" << label << ">, color=\"red\"]\n";
104 else os << indent << s.name << " [label=<" << label << ">]\n";
106 os << indent << s.name << " -> " << s.initialInnerState << " [style = dashed]\n";
107 os << indent << "subgraph cluster_" << s.name << " {\n" << indent_inc;
108 os << indent << "label = \"" << s.name << "\"\n";
109 os << indent << s.initialInnerState << " [peripheries=2]\n";
// Print the inner states through the Context overload. Binding a const
// reference selects that overload without copying the whole state map into
// a temporary, which a by-value static_cast<Context> would do on every call.
110 os << static_cast<const Context&>(s);
111 os << indent_dec << indent << "}\n";
// Prints every state of a Context, iterating the underlying std::map in key
// (state-name) order. The loop body is not visible in this listing.
117 ostream& operator<<(ostream& os, const Context& c)
119 for (Context::const_iterator i = c.begin(), e = c.end(); i != e; i++) {
// Transition: an edge from state `src` to state `dst`, labelled with `event`.
// All three names are const and fixed at construction.
129 const string src, dst, event;
130 Transition(string src, string dst, string event) : src(src), dst(dst), event(event) {}
// Emits a transition as a labelled DOT edge: src -> dst [label = "event"].
133 ostream& operator<<(ostream& os, const Transition& t)
135 os << indent << t.src << " -> " << t.dst << " [label = \"" << t.event << "\"]\n";
// Machine: a top-level state machine. Its top-level states live in the
// inherited Context map.
140 class Machine : public Context
// Name of the machine's initial state, drawn with a double border.
143 string initial_state;
146 explicit Machine(string name) : name(name) {}
148 void setInitialState(string name) { initial_state = name; }
150 friend ostream& operator<<(ostream& os, const Machine& m);
// Emits a machine as a DOT subgraph named after the machine. The initial state
// gets a double border (peripheries=2), then all top-level states are printed.
153 ostream& operator<<(ostream& os, const Machine& m)
155 os << indent << "subgraph " << m.name << " {\n" << indent_inc;
156 os << indent << m.initial_state << " [peripheries=2]\n";
// Bind a const reference so the Context overload is used without copying the
// machine's whole state map (a by-value static_cast<Context> would copy it).
157 os << static_cast<const Context&>(m);
158 os << indent_dec << indent << "}\n";
// Model: every state machine found in the translation unit, keyed by name,
// plus the transitions collected while visiting reactions.
163 class Model : public map<string, Machine>
165 Context undefined; // For forward-declared state classes
// Transitions are allocated with `new` by the visitors.
// NOTE(review): nothing visible deletes them; consider storing Transition by
// value or via std::unique_ptr.
167 list< Transition*> transitions;
// Adds a machine under its name. insert() does not overwrite an existing entry.
169 iterator add(const Machine &m)
171 pair<iterator, bool> ret = insert(value_type(m.name, m));
// Records a placeholder state that was referenced (e.g. as an outer context)
// before its own definition was seen. operator[] overwrites any previous entry.
175 void addUndefinedState(State *m)
177 undefined[m->name] = m;
// Finds a context by name: first among the undefined placeholders, then among
// the machines by name, then recursively inside each machine.
181 Context *findContext(const string &name)
183 Context::iterator ci = undefined.find(name);
184 if (ci != undefined.end())
186 iterator i = find(name), e;
189 for (i = begin(), e = end(); i != e; ++i) {
190 Context *c = i->second.findContext(name);
// Finds a state by searching each machine's hierarchy (not `undefined`).
// The downcast is valid because Context::findContext returns entries of the
// map, which are State objects.
197 State *findState(const string &name)
199 for (iterator i = begin(), e = end(); i != e; ++i) {
200 Context *c = i->second.findContext(name);
202 return static_cast<State*>(c);
// Removes and returns the placeholder for `name` from `undefined`; the
// not-found result is on lines not visible here (presumably a null pointer).
208 State *removeFromUndefinedContexts(const string &name)
210 Context::iterator ci = undefined.find(name);
211 if (ci == undefined.end())
// Writes the whole model as one DOT digraph: every machine, then every transition.
// NOTE(review): failure to open `fn` is not detected on the visible lines.
217 void write_as_dot_file(string fn)
219 ofstream f(fn.c_str());
220 f << "digraph statecharts {\n" << indent_inc;
221 for (iterator i = begin(), e = end(); i != e; i++)
223 for (list<Transition*>::iterator t = transitions.begin(), e = transitions.end(); t != e; ++t)
225 f << indent_dec << "}\n";
// Helper that adds a string-based isDerivedFrom() to CXXRecordDecl.
// NOTE(review): callers static_cast plain CXXRecordDecl* objects to this type.
// That is formally undefined behavior; it relies on this class adding no data
// members or virtual functions.
231 class MyCXXRecordDecl : public CXXRecordDecl
// lookupInBases() callback: true when the base named by Specifier has the
// canonical qualified name given by the opaque user-data pointer (qualName,
// declared on a line not visible here).
// NOTE(review): getAs<RecordType>() returns null for non-record bases
// (e.g. dependent types); confirm that line 239 guards rt.
233 static bool FindBaseClassString(const CXXBaseSpecifier *Specifier,
237 string qn(static_cast<const char*>(qualName));
238 const RecordType *rt = Specifier->getType()->getAs<RecordType>();
240 TagDecl *canon = rt->getDecl()->getCanonicalDecl();
241 return canon->getQualifiedNameAsString() == qn;
// True when this class derives (directly or indirectly) from the class whose
// qualified name is baseStr. Inheritance paths are recorded only when Base is
// requested; on success *Base receives the last specifier of the first path.
245 bool isDerivedFrom(const char *baseStr, CXXBaseSpecifier const **Base = 0) const {
246 CXXBasePaths Paths(/*FindAmbiguities=*/false, /*RecordPaths=*/!!Base, /*DetectVirtual=*/false);
247 Paths.setOrigin(const_cast<MyCXXRecordDecl*>(this));
248 if (!lookupInBases(&FindBaseClassString, const_cast<char*>(baseStr), Paths))
251 *Base = Paths.front().back().Base;
// Walks the body of a react() method and records one Transition for each
// `transit<Dst>(...)` member call found. The edge goes from SrcState to Dst,
// labelled with the event handled by that react().
256 class FindTransitVisitor : public RecursiveASTVisitor<FindTransitVisitor>
259 const CXXRecordDecl *SrcState;
260 const Type *EventType;
262 explicit FindTransitVisitor(Model::Model &model, const CXXRecordDecl *SrcState, const Type *EventType)
263 : model(model), SrcState(SrcState), EventType(EventType) {}
265 bool VisitMemberExpr(MemberExpr *E) {
266 if (E->getMemberNameInfo().getAsString() != "transit")
// The first explicit template argument of transit<> is the destination state.
// NOTE(review): getAsCXXRecordDecl() returns null when the type is not a class
// (e.g. a dependent type); DstState and Event are dereferenced without a check.
268 if (E->hasExplicitTemplateArgs()) {
269 const Type *DstStateType = E->getExplicitTemplateArgs()[0].getArgument().getAsType().getTypePtr();
270 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
271 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
272 Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
273 model.transitions.push_back(T);
// Main AST visitor: finds state machines and states and fills the Model.
279 class Visitor : public RecursiveASTVisitor<Visitor>
283 DiagnosticsEngine &Diags;
// Custom diagnostic IDs, registered in the constructor.
284 unsigned diag_unhandled_reaction_type, diag_unhandled_reaction_decl,
285 diag_found_state, diag_found_statemachine, diag_no_history, diag_missing_reaction, diag_warning;
// One flag per react() overload of the current state, in lookup order: set to
// true by HandleCustomReaction when the overload matches an event from the
// reactions typedef. Consumed and cleared by checkAllReactMethods.
286 std::vector<bool> reactMethodVector;
// Also visit template instantiations, not only written declarations.
289 bool shouldVisitTemplateInstantiations() const { return true; }
// Stores the AST context, the model to fill and the diagnostics engine, and
// registers the plugin's custom diagnostics. %0/%1 are filled by the `<<`
// operands at each Diag() call site.
291 explicit Visitor(ASTContext *Context, Model::Model &model, DiagnosticsEngine &Diags)
292 : ASTCtx(Context), model(model), Diags(Diags)
294 diag_found_statemachine =
295 Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found statemachine '%0'");
297 Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found state '%0'");
298 diag_unhandled_reaction_type =
299 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction type '%0'");
300 diag_unhandled_reaction_decl =
301 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction decl '%0'");
// This ID belongs in diag_no_history. Assigning it to
// diag_unhandled_reaction_decl would overwrite the ID registered just above,
// and unhandled reaction decls would be reported as "History is not yet supported".
302 diag_no_history =
303 Diags.getCustomDiagID(DiagnosticsEngine::Error, "History is not yet supported");
304 diag_missing_reaction =
305 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Missing react method for event '%0'");
307 Diags.getCustomDiagID(DiagnosticsEngine::Warning, "'%0' %1");
// Starts reporting diagnostic DiagID at Loc. The returned builder takes `<<`
// arguments for the diagnostic's %N slots.
310 DiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID) { return Diags.Report(Loc, DiagID); }
// Warns about react() methods of SrcState whose event parameter does not
// appear in the state's reactions typedef, then clears reactMethodVector for
// the next state. The index i (declared on a line not visible here) walks the
// react() overloads in lookup order, in parallel with reactMethodVector.
312 void checkAllReactMethods(const CXXRecordDecl *SrcState)
315 IdentifierInfo& II = ASTCtx->Idents.get("react");
316 for (DeclContext::lookup_const_result ReactRes = SrcState->lookup(DeclarationName(&II));
317 ReactRes.first != ReactRes.second; ++ReactRes.first) {
// NOTE(review): HandleCustomReaction sets the flag to true when the overload
// DID match an event from the typedef, yet this branch warns "missing in
// typedef" when the flag is true. This looks inverted unless lines missing
// from this listing negate it — confirm.
318 if (i<reactMethodVector.size()) {
319 if (reactMethodVector[i] == true) {
// NOTE(review): unlike HandleCustomReaction, the visible lines do not check
// dyn_cast for null or that the method has at least one parameter.
320 CXXMethodDecl *React = dyn_cast<CXXMethodDecl>(*ReactRes.first);
321 Diag(React->getParamDecl(0)->getLocStart(), diag_warning)
322 << React->getParamDecl(0)->getType().getAsString() << " missing in typedef";
// Overloads beyond the vector's size were never matched; warn for them too.
325 CXXMethodDecl *React = dyn_cast<CXXMethodDecl>(*ReactRes.first);
326 Diag(React->getParamDecl(0)->getLocStart(), diag_warning)
327 << React->getParamDecl(0)->getType().getAsString() << " missing in typedef";
331 reactMethodVector.clear();
// Handles a custom_reaction<Event>. It looks up the react() overloads declared
// directly in SrcState and picks the one whose first parameter (with an lvalue
// reference stripped) is EventType. It collects transit<> calls from that
// method's body and marks the overload in reactMethodVector. The return value
// is presumably "a matching react() was found" (HandleReaction reports
// diag_missing_reaction when it is false); the return lines are not visible here.
334 bool HandleCustomReaction(const CXXRecordDecl *SrcState, const Type *EventType)
337 IdentifierInfo& II = ASTCtx->Idents.get("react");
338 // TODO: Lookup for react even in base classes - probably by using Sema::LookupQualifiedName()
339 for (DeclContext::lookup_const_result ReactRes = SrcState->lookup(DeclarationName(&II));
340 ReactRes.first != ReactRes.second; ++ReactRes.first) {
341 if (CXXMethodDecl *React = dyn_cast<CXXMethodDecl>(*ReactRes.first)) {
342 if (React->getNumParams() >= 1) {
343 const ParmVarDecl *p = React->getParamDecl(0);
344 const Type *ParmType = p->getType().getTypePtr();
// Grow the per-overload flag vector the first time overload i is seen.
345 if (i == reactMethodVector.size()) reactMethodVector.push_back(false);
346 if (ParmType->isLValueReferenceType())
347 ParmType = dyn_cast<LValueReferenceType>(ParmType)->getPointeeType().getTypePtr();
// NOTE(review): this compares Type pointers, so differently sugared spellings
// of the same event type (typedef, elaborated name) may fail to match.
348 if (ParmType == EventType) {
349 FindTransitVisitor(model, SrcState, EventType).TraverseStmt(React->getBody());
350 reactMethodVector[i] = true;
354 Diag(React->getLocStart(), diag_warning)
355 << React << "has not a parameter";
357 Diag((*ReactRes.first)->getSourceRange().getBegin(), diag_warning)
358 << (*ReactRes.first)->getDeclKindName() << "is not supported as react method";
// Interprets one reaction type of state SrcState:
//   elaborated type             -> recurse on the named type
//   transition<Event, Dst>      -> record a Transition SrcState -> Dst
//   custom_reaction<Event>      -> HandleCustomReaction; error if no react() matches
//   deferral<Event>             -> add Event to the state's deferred events
//   boost::mpl::list<...>       -> recurse on every list argument
//   any other template          -> "Unhandled reaction type" with the template name
//   any other type class        -> "Unhandled reaction type" with the type class name
// NOTE(review): the getAsCXXRecordDecl() and findState() results are
// dereferenced on the visible lines; confirm the missing lines guard nulls.
364 void HandleReaction(const Type *T, const SourceLocation Loc, CXXRecordDecl *SrcState)
366 // TODO: Improve Loc tracking
367 if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(T))
368 HandleReaction(ET->getNamedType().getTypePtr(), Loc, SrcState);
369 else if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(T)) {
370 string name = TST->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
371 if (name == "boost::statechart::transition") {
372 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
373 const Type *DstStateType = TST->getArg(1).getAsType().getTypePtr();
374 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
375 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
// Note: this local T shadows the parameter T within this branch.
377 Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
378 model.transitions.push_back(T);
379 } else if (name == "boost::statechart::custom_reaction") {
380 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
381 if (!HandleCustomReaction(SrcState, EventType)) {
382 Diag(SrcState->getLocation(), diag_missing_reaction) << EventType->getAsCXXRecordDecl()->getName();
384 } else if (name == "boost::statechart::deferral") {
385 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
386 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
388 Model::State *s = model.findState(SrcState->getName());
390 s->addDeferredEvent(Event->getName());
391 } else if (name == "boost::mpl::list") {
392 for (TemplateSpecializationType::iterator Arg = TST->begin(), End = TST->end(); Arg != End; ++Arg)
393 HandleReaction(Arg->getAsType().getTypePtr(), Loc, SrcState);
395 Diag(Loc, diag_unhandled_reaction_type) << name;
397 Diag(Loc, diag_unhandled_reaction_type) << T->getTypeClassName();
// Handles one declaration named `reactions` in a state class. A typedef is
// resolved to its underlying type and dispatched to HandleReaction(Type);
// any other declaration kind is reported as an unhandled reaction decl.
// It then runs checkAllReactMethods for the state. Which branch that call
// belongs to depends on line 405, which is not visible here.
400 void HandleReaction(const NamedDecl *Decl, CXXRecordDecl *SrcState)
402 if (const TypedefDecl *r = dyn_cast<TypedefDecl>(Decl))
403 HandleReaction(r->getCanonicalDecl()->getUnderlyingType().getTypePtr(),
404 r->getLocStart(), SrcState);
406 Diag(Decl->getLocation(), diag_unhandled_reaction_decl) << Decl->getDeclKindName();
407 checkAllReactMethods(SrcState);
// Returns the source-located template argument number ArgNum (0-based) of a
// template specialization TypeLoc, looking through an elaborated type first.
// If the specialization has too few arguments, or T is some other kind of type,
// it emits a warning and returns an empty (Null-kind) TemplateArgumentLoc.
// `ignore` is presumably meant to suppress the warning for optional arguments;
// its use is on lines not visible here — TODO confirm.
410 TemplateArgumentLoc getTemplateArgLoc(const TypeLoc &T, unsigned ArgNum, bool ignore)
412 if (const ElaboratedTypeLoc *ET = dyn_cast<ElaboratedTypeLoc>(&T))
413 return getTemplateArgLoc(ET->getNamedTypeLoc(), ArgNum, ignore);
414 else if (const TemplateSpecializationTypeLoc *TST = dyn_cast<TemplateSpecializationTypeLoc>(&T)) {
415 if (TST->getNumArgs() >= ArgNum+1) {
416 return TST->getArgLoc(ArgNum);
419 Diag(TST->getBeginLoc(), diag_warning) << TST->getType()->getTypeClassName() << "has not enough arguments" << TST->getSourceRange();
421 Diag(T.getBeginLoc(), diag_warning) << T.getType()->getTypeClassName() << "type as template argument is not supported" << T.getSourceRange();
422 return TemplateArgumentLoc();
// getTemplateArgLoc applied to the written type of a base-class specifier,
// e.g. the Context argument of simple_state<Self, Context, ...>.
425 TemplateArgumentLoc getTemplateArgLocOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, bool ignore) {
426 return getTemplateArgLoc(Base->getTypeSourceInfo()->getTypeLoc(), ArgNum, ignore);
// Resolves template argument ArgNum of Base to a class declaration and stores
// its location in Loc. For a type argument it returns getAsCXXRecordDecl(),
// which is null when the type is not a class. Null arguments were already
// diagnosed; other argument kinds get an "unsupported kind" warning (the kind
// enum is streamed as its integer value). The result for non-type kinds is on
// lines not visible here — presumably null.
429 CXXRecordDecl *getTemplateArgDeclOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, TemplateArgumentLoc &Loc, bool ignore = false) {
430 Loc = getTemplateArgLocOfBase(Base, ArgNum, ignore);
431 switch (Loc.getArgument().getKind()) {
432 case TemplateArgument::Type:
433 return Loc.getTypeSourceInfo()->getType()->getAsCXXRecordDecl();
434 case TemplateArgument::Null:
435 // Diag() was already called
438 Diag(Loc.getSourceRange().getBegin(), diag_warning) << Loc.getArgument().getKind() << "unsupported kind" << Loc.getSourceRange();
// Convenience overload for callers that do not need the argument's location.
443 CXXRecordDecl *getTemplateArgDeclOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, bool ignore = false) {
444 TemplateArgumentLoc Loc;
445 return getTemplateArgDeclOfBase(Base, ArgNum, Loc, ignore);
// Registers a state class derived from simple_state<Self, Context, InnerInitial, ...>.
// Template argument 1 is the outer context; if that context is not known yet,
// a placeholder State is parked in the model's undefined set. Argument 2 is the
// optional inner initial state. Finally the state's `reactions` declarations
// are processed; a state without any is flagged (drawn red).
448 void handleSimpleState(CXXRecordDecl *RecordDecl, const CXXBaseSpecifier *Base)
450 string name(RecordDecl->getName()); //getQualifiedNameAsString());
452 Diag(RecordDecl->getLocStart(), diag_found_state) << name;
455 // Either we saw a reference to forward declared state
456 // before, or we create a new state.
457 if (!(state = model.removeFromUndefinedContexts(name)))
458 state = new Model::State(name);
// NOTE(review): Context is null when argument 1 is not a class type, and it is
// dereferenced below; confirm that line 461 guards it.
460 CXXRecordDecl *Context = getTemplateArgDeclOfBase(Base, 1);
462 Model::Context *c = model.findContext(Context->getName());
// The outer context has not been seen yet: create a placeholder for it.
464 Model::State *s = new Model::State(Context->getName());
465 model.addUndefinedState(s);
470 //TODO support more initial states
471 TemplateArgumentLoc Loc;
// ignore = true: the inner initial state argument is optional.
// NOTE(review): the static_cast to MyCXXRecordDecl is formally undefined behavior.
472 if (MyCXXRecordDecl *InnerInitialState =
473 static_cast<MyCXXRecordDecl*>(getTemplateArgDeclOfBase(Base, 2, Loc, true))) {
474 if (InnerInitialState->isDerivedFrom("boost::statechart::simple_state") ||
475 InnerInitialState->isDerivedFrom("boost::statechart::state_machine")) {
476 state->setInitialInnerState(InnerInitialState->getName());
479 Diag(Loc.getTypeSourceInfo()->getTypeLoc().getLocStart(), diag_warning)
480 << InnerInitialState->getName() << " as inner initial state is not supported" << Loc.getSourceRange();
483 // if (CXXRecordDecl *History = getTemplateArgDecl(Base->getType().getTypePtr(), 3))
484 // Diag(History->getLocStart(), diag_no_history);
// typedef_num (declared on a line not visible here) counts the `reactions`
// declarations found in this class.
486 IdentifierInfo& II = ASTCtx->Idents.get("reactions");
487 // TODO: Lookup for reactions even in base classes - probably by using Sema::LookupQualifiedName()
488 for (DeclContext::lookup_result Reactions = RecordDecl->lookup(DeclarationName(&II));
489 Reactions.first != Reactions.second; ++Reactions.first, typedef_num++)
490 HandleReaction(*Reactions.first, RecordDecl);
491 if(typedef_num == 0) {
492 Diag(RecordDecl->getLocStart(), diag_warning)
493 << RecordDecl->getName() << "state has no typedef for reactions";
494 state->setNoTypedef();
// Registers a class derived from state_machine<Self, InitialState, ...>.
// Template argument 1 names the machine's initial state. The Machine is built
// by value; adding it to the model happens on lines not visible here
// (presumably model.add(m) — TODO confirm).
498 void handleStateMachine(CXXRecordDecl *RecordDecl, const CXXBaseSpecifier *Base)
500 Model::Machine m(RecordDecl->getName());
501 Diag(RecordDecl->getLocStart(), diag_found_statemachine) << m.name;
503 if (MyCXXRecordDecl *InitialState =
504 static_cast<MyCXXRecordDecl*>(getTemplateArgDeclOfBase(Base, 1)))
505 m.setInitialState(InitialState->getName());
// RecursiveASTVisitor callback for every class declaration. It skips
// non-definitions and a few known intermediate base templates, then
// classifies the class by its base: simple_state -> state, state_machine ->
// machine, event -> currently ignored (the recording line is commented out).
509 bool VisitCXXRecordDecl(CXXRecordDecl *Declaration)
511 if (!Declaration->isCompleteDefinition())
513 if (Declaration->getQualifiedNameAsString() == "boost::statechart::state" ||
514 Declaration->getQualifiedNameAsString() == "TimedState" ||
515 Declaration->getQualifiedNameAsString() == "TimedSimpleState")
516 return true; // This is an "abstract class" not a real state
// NOTE(review): downcast of a plain CXXRecordDecl; formally undefined behavior
// (see MyCXXRecordDecl).
518 MyCXXRecordDecl *RecordDecl = static_cast<MyCXXRecordDecl*>(Declaration);
// Filled in by isDerivedFrom(); read only in the branches where it returned true.
519 const CXXBaseSpecifier *Base;
521 if (RecordDecl->isDerivedFrom("boost::statechart::simple_state", &Base))
522 handleSimpleState(RecordDecl, Base);
523 else if (RecordDecl->isDerivedFrom("boost::statechart::state_machine", &Base))
524 handleStateMachine(RecordDecl, Base);
525 else if (RecordDecl->isDerivedFrom("boost::statechart::event"))
527 //sc.events.push_back(RecordDecl->getNameAsString());
// AST consumer: once the whole translation unit is parsed, it traverses it with
// Visitor to fill the model, then writes the model as a DOT file to destFileName.
// (The members model, visitor and destFileName are declared on lines not
// visible here.)
534 class VisualizeStatechartConsumer : public clang::ASTConsumer
540 explicit VisualizeStatechartConsumer(ASTContext *Context, std::string destFileName,
541 DiagnosticsEngine &D)
542 : visitor(Context, model, D), destFileName(destFileName) {}
544 virtual void HandleTranslationUnit(clang::ASTContext &Context) {
545 visitor.TraverseDecl(Context.getTranslationUnitDecl());
546 model.write_as_dot_file(destFileName);
// Plugin action: creates the consumer and parses plugin arguments.
550 class VisualizeStatechartAction : public PluginASTAction
// The output name is the input path with everything from its last '.' removed
// (the whole path when it contains no '.'). Line 556 is not visible here;
// presumably it appends the output extension — TODO confirm.
// NOTE(review): find_last_of('.') also matches a dot in a directory name
// when the file itself has no extension.
553 ASTConsumer *CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) {
554 size_t dot = getCurrentFile().find_last_of('.');
555 std::string dest = getCurrentFile().substr(0, dot);
557 return new VisualizeStatechartConsumer(&CI.getASTContext(), dest, CI.getDiagnostics());
// Echoes every plugin argument to stderr. "-an-error" triggers the example
// error path (placeholder code, per its own comment). A first argument of
// "help" prints the help text.
560 bool ParseArgs(const CompilerInstance &CI,
561 const std::vector<std::string>& args) {
562 for (unsigned i = 0, e = args.size(); i != e; ++i) {
563 llvm::errs() << "Visualizer arg = " << args[i] << "\n";
565 // Example error handling.
566 if (args[i] == "-an-error") {
567 DiagnosticsEngine &D = CI.getDiagnostics();
568 unsigned DiagID = D.getCustomDiagID(
569 DiagnosticsEngine::Error, "invalid argument '" + args[i] + "'");
574 if (args.size() && args[0] == "help")
575 PrintHelp(llvm::errs());
// Placeholder help text.
579 void PrintHelp(llvm::raw_ostream& ros) {
580 ros << "Help for Visualize Statechart plugin goes here\n";
// Registers VisualizeStatechartAction in clang's frontend plugin registry under
// the name "visualize-statechart" with the description "visualize statechart".
585 static FrontendPluginRegistry::Add<VisualizeStatechartAction> X("visualize-statechart", "visualize statechart");