2 ////////////////////////////////////////////////////////////////////////////////////////
4 // This file is part of Boost Statechart Viewer.
6 // Boost Statechart Viewer is free software: you can redistribute it and/or modify
7 // it under the terms of the GNU General Public License as published by
8 // the Free Software Foundation, either version 3 of the License, or
9 // (at your option) any later version.
11 // Boost Statechart Viewer is distributed in the hope that it will be useful,
12 // but WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 // GNU General Public License for more details.
16 // You should have received a copy of the GNU General Public License
17 // along with Boost Statechart Viewer. If not, see <http://www.gnu.org/licenses/>.
19 ////////////////////////////////////////////////////////////////////////////////////////
21 //standard header files
27 #include "llvm/Support/raw_ostream.h"
28 #include "llvm/Support/raw_os_ostream.h"
31 #include "clang/AST/ASTConsumer.h"
32 #include "clang/AST/ASTContext.h"
33 #include "clang/AST/CXXInheritance.h"
34 #include "clang/AST/RecursiveASTVisitor.h"
35 #include "clang/Frontend/CompilerInstance.h"
36 #include "clang/Frontend/FrontendPluginRegistry.h"
38 using namespace clang;
// Index of the per-stream storage slot (ios_base::iword) used by the
// indent / indent_inc / indent_dec manipulators. The slot is obtained from
// ios_base::xalloc() once and kept in a function-local static, so every call
// refers to the same slot (presumably by returning `i` - the return statement
// is not visible here).
44 inline int getIndentLevelIdx() {
45 static int i = ios_base::xalloc();
49 ostream& indent(ostream& os) { os << setw(2*os.iword(getIndentLevelIdx())) << ""; return os; }
50 ostream& indent_inc(ostream& os) { os.iword(getIndentLevelIdx())++; return os; }
51 ostream& indent_dec(ostream& os) { os.iword(getIndentLevelIdx())--; return os; }
// A set of states keyed by state name. std::map is a public base, so the
// usual map operations (find, insert, iteration) apply directly to a Context.
55 class Context : public map<string, State*> {
// Inserts `state` under its own name (defined out of line).
57 iterator add(State *state);
// Looks up a context by name, also searching nested states (defined out of line).
58 Context *findContext(const string &name);
// One state of a statechart. A State is itself a Context, so it can contain
// nested states.
61 class State : public Context
// Name of the inner state that operator<< marks as the initial one.
63 string initialInnerState;
// Names of the events this state defers.
// NOTE(review): "deffered" is a misspelling of "deferred"; renaming it also
// requires updating operator<<(ostream&, const State&), which reads this member.
64 list<string> defferedEvents;
// Creates a state called `name`; the noTypedef flag starts out false.
68 explicit State(string name) : noTypedef(false), name(name) {}
// Records the name of the initial inner state.
69 void setInitialInnerState(string name) { initialInnerState = name; }
// Appends `name` to the list of deferred events.
70 void addDeferredEvent(const string &name) { defferedEvents.push_back(name); }
// Sets the noTypedef flag; operator<< draws flagged states with color="red".
71 void setNoTypedef() { noTypedef = true;}
// The printer reads the members of this class directly.
72 friend ostream& operator<<(ostream& os, const State& s);
// Inserts `state` into this context, keyed by state->name.
// std::map::insert does not replace an existing entry with the same key;
// ret.second tells whether the insertion actually happened.
76 Context::iterator Context::add(State *state)
78 pair<iterator, bool> ret = insert(value_type(state->name, state));
// Searches for the context called `name`: a direct lookup in this context
// first, then a recursive search inside every contained state.
82 Context *Context::findContext(const string &name)
// Direct lookup among the immediate children.
84 iterator i = find(name), e;
// Depth-first descent into each child (presumably reached only when the
// direct lookup above did not succeed - the check is not visible here).
87 for (i = begin(), e = end(); i != e; ++i) {
88 Context *c = i->second->findContext(name);
// Forward declaration: the State printer below needs the Context printer,
// which is defined after it.
95 ostream& operator<<(ostream& os, const Context& c);
// Writes state `s` in Graphviz DOT syntax: a node for the state itself and a
// "cluster_<name>" subgraph holding its nested states.
97 ostream& operator<<(ostream& os, const State& s)
// The node label is an HTML-like label: the state name followed by one
// "<event> / defer" line per deferred event.
99 string label = s.name;
100 for (list<string>::const_iterator i = s.defferedEvents.begin(), e = s.defferedEvents.end(); i != e; ++i)
101 label.append("<br />").append(*i).append(" / defer");
// States flagged with noTypedef are drawn in red.
102 if (s.noTypedef) os << indent << s.name << " [label=<" << label << ">, color=\"red\"]\n";
103 else os << indent << s.name << " [label=<" << label << ">]\n";
// Dashed edge from the state to its initial inner state, then the cluster
// with the nested states; the initial inner state gets a double outline.
105 os << indent << s.name << " -> " << s.initialInnerState << " [style = dashed]\n";
106 os << indent << "subgraph cluster_" << s.name << " {\n" << indent_inc;
107 os << indent << "label = \"" << s.name << "\"\n";
108 os << indent << s.initialInnerState << " [peripheries=2]\n";
// Prints the nested states through the Context printer.
// NOTE(review): static_cast<Context>(s) builds a temporary copy of the whole
// map; static_cast<const Context&>(s) would select the same overload without
// the copy.
109 os << static_cast<Context>(s);
110 os << indent_dec << indent << "}\n";
// Printer for a Context: walks all entries of `c` in key (state name) order.
// The loop body is not visible here; presumably it prints each stored state.
116 ostream& operator<<(ostream& os, const Context& c)
118 for (Context::const_iterator i = c.begin(), e = c.end(); i != e; i++) {
// Names of the source state, the destination state and the triggering event.
// They are const, so a Transition cannot be changed after construction.
128 const string src, dst, event;
// Stores copies of the three names.
129 Transition(string src, string dst, string event) : src(src), dst(dst), event(event) {}
132 ostream& operator<<(ostream& os, const Transition& t)
134 os << indent << t.src << " -> " << t.dst << " [label = \"" << t.event << "\"]\n";
// A state machine: a Context holding the machine's top-level states, plus
// the name of its initial state.
139 class Machine : public Context
// Name of the state that operator<< marks as initial (double outline).
142 string initial_state;
// Creates a machine called `name`.
145 explicit Machine(string name) : name(name) {}
// Records the name of the machine's initial state.
147 void setInitialState(string name) { initial_state = name; }
// The printer reads the members of this class directly.
149 friend ostream& operator<<(ostream& os, const Machine& m);
152 ostream& operator<<(ostream& os, const Machine& m)
154 os << indent << "subgraph " << m.name << " {\n" << indent_inc;
155 os << indent << m.initial_state << " [peripheries=2]\n";
156 os << static_cast<Context>(m);
157 os << indent_dec << indent << "}\n";
// Top-level model: the state machines keyed by machine name (std::map is a
// public base), the states that were referenced before being defined, and the
// list of transitions.
162 class Model : public map<string, Machine>
164 Context undefined; // For forward-declared state classes
// Transitions are held as raw pointers; the visitors create them with `new`.
// NOTE(review): no matching delete is visible in this view - confirm ownership.
166 list< Transition*> transitions;
// Stores a copy of machine `m` under m.name.
168 iterator add(const Machine &m)
170 pair<iterator, bool> ret = insert(value_type(m.name, m));
// Registers `m` among the undefined states; operator[] overwrites any entry
// that already uses the same name.
174 void addUndefinedState(State *m)
176 undefined[m->name] = m;
// Finds the context called `name`. Lookup order as far as visible: the
// undefined states, the machines by name, then the contents of every machine.
180 Context *findContext(const string &name)
182 Context::iterator ci = undefined.find(name);
183 if (ci != undefined.end())
185 iterator i = find(name), e;
188 for (i = begin(), e = end(); i != e; ++i) {
189 Context *c = i->second.findContext(name);
// Searches every machine for a context called `name` and returns it as a
// State. The static_cast assumes that whatever findContext() returns here is
// really a State object.
196 State *findState(const string &name)
198 for (iterator i = begin(), e = end(); i != e; ++i) {
199 Context *c = i->second.findContext(name);
201 return static_cast<State*>(c);
// Looks `name` up among the undefined states (presumably removing and
// returning the entry when found - the rest of the body is not visible here).
207 State *removeFromUndefinedContexts(const string &name)
209 Context::iterator ci = undefined.find(name);
210 if (ci == undefined.end())
// Writes the model to file `fn` as a DOT digraph called "statecharts",
// iterating over the machines and then over the transitions.
// NOTE(review): the stream is not checked after opening - confirm whether a
// failure to create the file should be reported.
216 void write_as_dot_file(string fn)
218 ofstream f(fn.c_str());
219 f << "digraph statecharts {\n" << indent_inc;
220 for (iterator i = begin(), e = end(); i != e; i++)
222 for (list<Transition*>::iterator t = transitions.begin(), e = transitions.end(); t != e; ++t)
224 f << indent_dec << "}\n";
// Adds a "derived from the class with this qualified name" test to
// CXXRecordDecl. Elsewhere in this file existing CXXRecordDecl pointers are
// static_cast to this type.
// NOTE(review): such a downcast of an object that was not created as
// MyCXXRecordDecl is formally undefined behaviour; it presumably relies on
// this class adding no data members - confirm.
230 class MyCXXRecordDecl : public CXXRecordDecl
// Callback for lookupInBases(): true when the base named by `Specifier` is a
// record whose canonical declaration has the qualified name passed through
// `qualName` (a C string).
232 static bool FindBaseClassString(const CXXBaseSpecifier *Specifier,
236 string qn(static_cast<const char*>(qualName));
237 const RecordType *rt = Specifier->getType()->getAs<RecordType>();
239 TagDecl *canon = rt->getDecl()->getCanonicalDecl();
240 return canon->getQualifiedNameAsString() == qn;
// Tests whether this class has a base whose qualified name is `baseStr`.
// Paths are recorded only when the caller passes `Base`; in that case *Base
// receives the last base specifier of the first path found.
244 bool isDerivedFrom(const char *baseStr, CXXBaseSpecifier const **Base = 0) const {
245 CXXBasePaths Paths(/*FindAmbiguities=*/false, /*RecordPaths=*/!!Base, /*DetectVirtual=*/false);
246 Paths.setOrigin(const_cast<MyCXXRecordDecl*>(this));
// The name travels through the callback's untyped user-data argument, hence
// the const_cast.
247 if (!lookupInBases(&FindBaseClassString, const_cast<char*>(baseStr), Paths))
250 *Base = Paths.front().back().Base;
// AST visitor that looks for member expressions named "transit" with explicit
// template arguments and records one transition per match, using the source
// state and event type it was constructed with.
255 class FindTransitVisitor : public RecursiveASTVisitor<FindTransitVisitor>
258 const CXXRecordDecl *SrcState;
259 const Type *EventType;
261 explicit FindTransitVisitor(Model::Model &model, const CXXRecordDecl *SrcState, const Type *EventType)
262 : model(model), SrcState(SrcState), EventType(EventType) {}
// Called for every member expression in the traversed statement.
264 bool VisitMemberExpr(MemberExpr *E) {
265 if (E->getMemberNameInfo().getAsString() != "transit")
267 if (E->hasExplicitTemplateArgs()) {
// The first explicit template argument is taken as the destination state.
// NOTE(review): getAsCXXRecordDecl() yields null for non-class types, and
// DstState / Event are dereferenced below without a check - confirm that the
// inputs guarantee class types.
268 const Type *DstStateType = E->getExplicitTemplateArgs()[0].getArgument().getAsType().getTypePtr();
269 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
270 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
// The transition is heap-allocated and its pointer is stored in the model.
271 Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
272 model.transitions.push_back(T);
// Main AST visitor: inspects C++ record declarations and fills the model with
// the state machines, states and transitions it recognises.
278 class Visitor : public RecursiveASTVisitor<Visitor>
// Engine used to report the plugin's notes, warnings and errors.
282 DiagnosticsEngine &Diags;
// Custom diagnostic IDs; they are registered in the constructor.
283 unsigned diag_unhandled_reaction_type, diag_unhandled_reaction_decl,
284 diag_found_state, diag_found_statemachine, diag_no_history, diag_missing_reaction, diag_warning;
// Asks RecursiveASTVisitor to traverse template instantiations as well.
287 bool shouldVisitTemplateInstantiations() const { return true; }
289 explicit Visitor(ASTContext *Context, Model::Model &model, DiagnosticsEngine &Diags)
290 : ASTCtx(Context), model(model), Diags(Diags)
292 diag_found_statemachine =
293 Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found statemachine '%0'");
295 Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found state '%0'");
296 diag_unhandled_reaction_type =
297 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction type '%0'");
298 diag_unhandled_reaction_decl =
299 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction decl '%0'");
300 diag_unhandled_reaction_decl =
301 Diags.getCustomDiagID(DiagnosticsEngine::Error, "History is not yet supported");
302 diag_missing_reaction =
303 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Missing react method for event '%0'");
305 Diags.getCustomDiagID(DiagnosticsEngine::Warning, "'%0' %1");
308 DiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID) { return Diags.Report(Loc, DiagID); }
// Handles a custom_reaction for `EventType` in state `SrcState`: looks for
// members named "react" in the state class and, for a method whose first
// parameter has the event type, scans the method body for transit<>() calls.
// The caller reports a missing react method when this returns false.
310 bool HandleCustomReaction(const CXXRecordDecl *SrcState, const Type *EventType)
312 IdentifierInfo& II = ASTCtx->Idents.get("react");
313 // TODO: Lookup for react even in base classes - probably by using Sema::LookupQualifiedName()
314 for (DeclContext::lookup_const_result ReactRes = SrcState->lookup(DeclarationName(&II));
315 ReactRes.first != ReactRes.second; ++ReactRes.first) {
316 if (CXXMethodDecl *React = dyn_cast<CXXMethodDecl>(*ReactRes.first)) {
317 if (React->getNumParams() >= 1) {
318 const ParmVarDecl *p = React->getParamDecl(0);
319 const Type *ParmType = p->getType().getTypePtr();
// A parameter of type "const Event &" is compared through its pointee type.
// NOTE(review): isLValueReferenceType() presumably looks through type sugar
// while dyn_cast<> does not, so a typedef'd reference type could make the
// dyn_cast return null here - confirm.
320 if (ParmType->isLValueReferenceType())
321 ParmType = dyn_cast<LValueReferenceType>(ParmType)->getPointeeType().getTypePtr();
// The comparison is by Type pointer identity.
322 if (ParmType == EventType) {
323 FindTransitVisitor(model, SrcState, EventType).TraverseStmt(React->getBody());
// Warning for a react method without parameters.
327 Diag(React->getLocStart(), diag_warning)
328 << React << "has not a parameter";
// Warning for a "react" member that is not a method.
330 Diag((*ReactRes.first)->getSourceRange().getBegin(), diag_warning)
331 << (*ReactRes.first)->getDeclKindName() << "is not supported as react method";
// Interprets one reaction type `T` declared by state `SrcState`. `Loc` is the
// location used for diagnostics. Elaborated types are unwrapped; template
// specializations are dispatched on the qualified template name; anything
// else is reported as an unhandled reaction type.
336 void HandleReaction(const Type *T, const SourceLocation Loc, CXXRecordDecl *SrcState)
338 // TODO: Improve Loc tracking
339 if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(T))
340 HandleReaction(ET->getNamedType().getTypePtr(), Loc, SrcState);
341 else if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(T)) {
342 string name = TST->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
// transition<Event, Destination>: record the edge SrcState -> Destination.
343 if (name == "boost::statechart::transition") {
344 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
345 const Type *DstStateType = TST->getArg(1).getAsType().getTypePtr();
// NOTE(review): getAsCXXRecordDecl() yields null for non-class types; Event
// and DstState are dereferenced below without a check.
346 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
347 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
349 Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
350 model.transitions.push_back(T);
// custom_reaction<Event>: transitions are taken from the react() method; an
// error is reported when HandleCustomReaction() returns false.
351 } else if (name == "boost::statechart::custom_reaction") {
352 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
353 if(!HandleCustomReaction(SrcState, EventType)) {
354 Diag(SrcState->getLocation(), diag_missing_reaction) << EventType->getAsCXXRecordDecl()->getName();
// deferral<Event>: add the event to the deferred events of the source state.
356 } else if (name == "boost::statechart::deferral") {
357 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
358 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
360 Model::State *s = model.findState(SrcState->getName());
362 s->addDeferredEvent(Event->getName());
// mpl::list<...>: every list element is handled as a reaction of its own.
363 } else if (name == "boost::mpl::list") {
364 for (TemplateSpecializationType::iterator Arg = TST->begin(), End = TST->end(); Arg != End; ++Arg)
365 HandleReaction(Arg->getAsType().getTypePtr(), Loc, SrcState);
// Unknown template: report its name.
367 Diag(Loc, diag_unhandled_reaction_type) << name;
// Not a template specialization: report the type class name.
369 Diag(Loc, diag_unhandled_reaction_type) << T->getTypeClassName();
// Handles one declaration named "reactions" found in state `SrcState`.
// A typedef is resolved to its underlying type and passed to the Type
// overload; any other declaration kind is reported as unhandled.
372 void HandleReaction(const NamedDecl *Decl, CXXRecordDecl *SrcState)
374 if (const TypedefDecl *r = dyn_cast<TypedefDecl>(Decl))
375 HandleReaction(r->getCanonicalDecl()->getUnderlyingType().getTypePtr(),
376 r->getLocStart(), SrcState);
378 Diag(Decl->getLocation(), diag_unhandled_reaction_decl) << Decl->getDeclKindName();
// Returns argument number `ArgNum` (0-based) of the template specialization
// written at `T`, looking through elaborated type locations. Unsupported
// cases end in a default-constructed TemplateArgumentLoc.
// `ignore` presumably suppresses the warnings - the conditions that use it
// are not visible here.
381 TemplateArgumentLoc getTemplateArgLoc(const TypeLoc &T, unsigned ArgNum, bool ignore)
383 if (const ElaboratedTypeLoc *ET = dyn_cast<ElaboratedTypeLoc>(&T))
384 return getTemplateArgLoc(ET->getNamedTypeLoc(), ArgNum, ignore);
385 else if (const TemplateSpecializationTypeLoc *TST = dyn_cast<TemplateSpecializationTypeLoc>(&T)) {
// The argument exists only if at least ArgNum+1 arguments were written.
386 if (TST->getNumArgs() >= ArgNum+1) {
387 return TST->getArgLoc(ArgNum);
390 Diag(TST->getBeginLoc(), diag_warning) << TST->getType()->getTypeClassName() << "has not enough arguments" << TST->getSourceRange();
392 Diag(T.getBeginLoc(), diag_warning) << T.getType()->getTypeClassName() << "type as template argument is not supported" << T.getSourceRange();
// Null argument: tells the caller that nothing usable was found.
393 return TemplateArgumentLoc();
396 TemplateArgumentLoc getTemplateArgLocOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, bool ignore) {
397 return getTemplateArgLoc(Base->getTypeSourceInfo()->getTypeLoc(), ArgNum, ignore);
// Resolves template argument `ArgNum` of base specifier `Base` to the C++
// record it names. `Loc` receives the argument's location info so callers can
// attach diagnostics to it; `ignore` is forwarded to the lookup.
400 CXXRecordDecl *getTemplateArgDeclOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, TemplateArgumentLoc &Loc, bool ignore = false) {
401 Loc = getTemplateArgLocOfBase(Base, ArgNum, ignore);
402 switch (Loc.getArgument().getKind()) {
// Type argument: the record declaration behind the type (null if the type
// is not a C++ class type).
403 case TemplateArgument::Type:
404 return Loc.getTypeSourceInfo()->getType()->getAsCXXRecordDecl();
405 case TemplateArgument::Null:
406 // Diag() was already called
// Any other argument kind is reported with a warning.
409 Diag(Loc.getSourceRange().getBegin(), diag_warning) << Loc.getArgument().getKind() << "unsupported kind" << Loc.getSourceRange();
414 CXXRecordDecl *getTemplateArgDeclOfBase(const CXXBaseSpecifier *Base, unsigned ArgNum, bool ignore = false) {
415 TemplateArgumentLoc Loc;
416 return getTemplateArgDeclOfBase(Base, ArgNum, Loc, ignore);
// Adds the state declared by `RecordDecl` to the model. `Base` is the
// simple_state<> base specifier; its template arguments name the enclosing
// context (argument 1) and the initial inner state (argument 2).
419 void handleSimpleState(CXXRecordDecl *RecordDecl, const CXXBaseSpecifier *Base)
// The unqualified class name is used as the state name.
422 string name(RecordDecl->getName()); //getQualifiedNameAsString());
423 Diag(RecordDecl->getLocStart(), diag_found_state) << name;
426 // Either we saw a reference to forward declared state
427 // before, or we create a new state.
428 if (!(state = model.removeFromUndefinedContexts(name)))
429 state = new Model::State(name);
// Template argument 1 names the context (machine or outer state) that
// contains this state.
431 CXXRecordDecl *Context = getTemplateArgDeclOfBase(Base, 1);
433 Model::Context *c = model.findContext(Context->getName());
// A new State for the context is created and registered as undefined
// (presumably only when the lookup above found nothing - the check is not
// visible here).
435 Model::State *s = new Model::State(Context->getName());
436 model.addUndefinedState(s);
441 //TODO support more initial states
// Template argument 2, when present, names the initial inner state; it is
// accepted only if it derives from simple_state or state_machine.
442 TemplateArgumentLoc Loc;
443 if (MyCXXRecordDecl *InnerInitialState =
444 static_cast<MyCXXRecordDecl*>(getTemplateArgDeclOfBase(Base, 2, Loc, true))) {
445 if (InnerInitialState->isDerivedFrom("boost::statechart::simple_state") ||
446 InnerInitialState->isDerivedFrom("boost::statechart::state_machine")) {
447 state->setInitialInnerState(InnerInitialState->getName());
450 Diag(Loc.getTypeSourceInfo()->getTypeLoc().getLocStart(), diag_warning)
451 << InnerInitialState->getName() << " as inner initial state is not supported" << Loc.getSourceRange();
454 // if (CXXRecordDecl *History = getTemplateArgDecl(Base->getType().getTypePtr(), 3))
455 // Diag(History->getLocStart(), diag_no_history);
// Every member named "reactions" is processed; typedef_num counts them.
457 IdentifierInfo& II = ASTCtx->Idents.get("reactions");
458 // TODO: Lookup for reactions even in base classes - probably by using Sema::LookupQualifiedName()
459 for (DeclContext::lookup_result Reactions = RecordDecl->lookup(DeclarationName(&II));
460 Reactions.first != Reactions.second; ++Reactions.first, typedef_num++)
461 HandleReaction(*Reactions.first, RecordDecl);
// A state without any "reactions" member gets a warning and is flagged so
// that it is drawn in red.
462 if(typedef_num == 0) {
463 Diag(RecordDecl->getLocStart(), diag_warning)
464 << RecordDecl->getName() << "state has no typedef for reactions";
465 state->setNoTypedef();
// Creates a machine for the state_machine<> subclass `RecordDecl`. `Base` is
// the state_machine<> base specifier; its template argument 1 names the
// machine's initial state.
469 void handleStateMachine(CXXRecordDecl *RecordDecl, const CXXBaseSpecifier *Base)
471 Model::Machine m(RecordDecl->getName());
472 Diag(RecordDecl->getLocStart(), diag_found_statemachine) << m.name;
// The initial state is set only when the argument resolves to a class.
474 if (MyCXXRecordDecl *InitialState =
475 static_cast<MyCXXRecordDecl*>(getTemplateArgDeclOfBase(Base, 1)))
476 m.setInitialState(InitialState->getName());
// RecursiveASTVisitor hook for C++ record declarations: classifies the class
// by its Boost.Statechart base and passes it to the matching handler.
480 bool VisitCXXRecordDecl(CXXRecordDecl *Declaration)
// Declarations without a complete definition are tested first (presumably
// skipped - the statement after the test is not visible here).
482 if (!Declaration->isCompleteDefinition())
484 if (Declaration->getQualifiedNameAsString() == "boost::statechart::state" ||
485 Declaration->getQualifiedNameAsString() == "TimedState" ||
486 Declaration->getQualifiedNameAsString() == "TimedSimpleState")
487 return true; // This is an "abstract class" not a real state
// Reinterpret the declaration as the helper type to use isDerivedFrom().
489 MyCXXRecordDecl *RecordDecl = static_cast<MyCXXRecordDecl*>(Declaration);
490 const CXXBaseSpecifier *Base;
// simple_state is tested before state_machine; `Base` receives the matching
// base specifier for the handler.
492 if (RecordDecl->isDerivedFrom("boost::statechart::simple_state", &Base))
493 handleSimpleState(RecordDecl, Base);
494 else if (RecordDecl->isDerivedFrom("boost::statechart::state_machine", &Base))
495 handleStateMachine(RecordDecl, Base);
496 else if (RecordDecl->isDerivedFrom("boost::statechart::event"))
498 //sc.events.push_back(RecordDecl->getNameAsString());
// AST consumer of the plugin: runs the visitor over the whole translation
// unit and then writes the collected model to the destination file.
505 class VisualizeStatechartConsumer : public clang::ASTConsumer
// `destFileName` is the path handed to Model::write_as_dot_file().
511 explicit VisualizeStatechartConsumer(ASTContext *Context, std::string destFileName,
512 DiagnosticsEngine &D)
513 : visitor(Context, model, D), destFileName(destFileName) {}
// Called with the AST context of the translation unit.
515 virtual void HandleTranslationUnit(clang::ASTContext &Context) {
516 visitor.TraverseDecl(Context.getTranslationUnitDecl());
517 model.write_as_dot_file(destFileName);
// Plugin action: creates the consumer for each input file and handles the
// plugin's command-line arguments.
521 class VisualizeStatechartAction : public PluginASTAction
// Builds the consumer. The destination name starts from the current input
// file name cut at its last '.'; if there is no '.', find_last_of() returns
// npos and substr() keeps the whole name.
524 ASTConsumer *CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) {
525 size_t dot = getCurrentFile().find_last_of('.');
526 std::string dest = getCurrentFile().substr(0, dot);
528 return new VisualizeStatechartConsumer(&CI.getASTContext(), dest, CI.getDiagnostics());
// Echoes every plugin argument to stderr; "-an-error" registers an error
// diagnostic and "help" as the first argument prints the help text.
531 bool ParseArgs(const CompilerInstance &CI,
532 const std::vector<std::string>& args) {
533 for (unsigned i = 0, e = args.size(); i != e; ++i) {
534 llvm::errs() << "Visualizer arg = " << args[i] << "\n";
536 // Example error handling.
537 if (args[i] == "-an-error") {
538 DiagnosticsEngine &D = CI.getDiagnostics();
// NOTE(review): the argument text becomes part of the diagnostic format
// string, so a '%' inside it would presumably be read as a placeholder.
539 unsigned DiagID = D.getCustomDiagID(
540 DiagnosticsEngine::Error, "invalid argument '" + args[i] + "'");
545 if (args.size() && args[0] == "help")
546 PrintHelp(llvm::errs());
// Writes the (placeholder) help text to `ros`.
550 void PrintHelp(llvm::raw_ostream& ros) {
551 ros << "Help for Visualize Statechart plugin goes here\n";
// Registers the action with Clang's frontend plugin registry under the name
// "visualize-statechart" with the description "visualize statechart".
556 static FrontendPluginRegistry::Add<VisualizeStatechartAction> X("visualize-statechart", "visualize statechart");