2 ////////////////////////////////////////////////////////////////////////////////////////
4 // This file is part of Boost Statechart Viewer.
6 // Boost Statechart Viewer is free software: you can redistribute it and/or modify
7 // it under the terms of the GNU General Public License as published by
8 // the Free Software Foundation, either version 3 of the License, or
9 // (at your option) any later version.
11 // Boost Statechart Viewer is distributed in the hope that it will be useful,
12 // but WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 // GNU General Public License for more details.
16 // You should have received a copy of the GNU General Public License
17 // along with Boost Statechart Viewer. If not, see <http://www.gnu.org/licenses/>.
19 ////////////////////////////////////////////////////////////////////////////////////////
21 //standard header files
27 #include "llvm/Support/raw_ostream.h"
28 #include "llvm/Support/raw_os_ostream.h"
31 #include "clang/AST/ASTConsumer.h"
32 #include "clang/AST/ASTContext.h"
33 #include "clang/AST/CXXInheritance.h"
34 #include "clang/AST/RecursiveASTVisitor.h"
35 #include "clang/Frontend/CompilerInstance.h"
36 #include "clang/Frontend/FrontendPluginRegistry.h"
38 using namespace clang;
// Returns the index of the std::ios_base::iword() slot that holds a
// stream's current indentation level.
//
// The slot is reserved once, through ios_base::xalloc(), the first time
// this function runs; every later call returns the same index, so all
// indentation manipulators share one per-stream counter.
inline int getIndentLevelIdx() {
    static const int idx = std::ios_base::xalloc();
    return idx;
}
49 ostream& indent(ostream& os) { os << setw(2*os.iword(getIndentLevelIdx())) << ""; return os; }
50 ostream& indent_inc(ostream& os) { os.iword(getIndentLevelIdx())++; return os; }
51 ostream& indent_dec(ostream& os) { os.iword(getIndentLevelIdx())--; return os; }
// Context: a map from a name to a State*, i.e. the states that live
// directly inside one enclosing scope.
// NOTE(review): the access specifier and the closing "};" of this class are
// not visible in this view - confirm against the complete file.
55 class Context : public map<string, State*> {
// Inserts `state` into this context (defined out of line).
57 iterator add(State *state);
// Looks for a context called `name` (defined out of line).
58 Context *findContext(const string &name);
// State: one state of the chart. Because it derives from Context it can
// itself hold nested states.
// NOTE(review): the `name` member initialised by the constructor, the
// access specifiers and the closing "};" are not visible in this view.
61 class State : public Context
// Name of the nested state marked as initial; stays empty until
// setInitialInnerState() is called.
63 string initialInnerState;
// Creates a state called `name`.
66 explicit State(string name) : name(name) {}
// Records which nested state is the initial one.
67 void setInitialInnerState(string name) { initialInnerState = name; }
// The printer reads initialInnerState, hence the friendship.
68 friend ostream& operator<<(ostream& os, const State& s);
72 Context::iterator Context::add(State *state)
74 pair<iterator, bool> ret = insert(value_type(state->name, state));
// Searches for a context called `name`: first as a direct entry of this
// context (find), then recursively inside every contained state.
// NOTE(review): the lines that act on a successful lookup and the final
// return are not visible here; presumably the matching context is returned
// and a null pointer means "not found" - confirm.
78 Context *Context::findContext(const string &name)
80 iterator i = find(name), e;
// Descend into each contained state in turn.
83 for (i = begin(), e = end(); i != e; ++i) {
84 Context *c = i->second->findContext(name);
// Forward declaration: the State printer below streams its nested Context.
92 ostream& operator<<(ostream& os, const Context& c);
// Writes state `s` as Graphviz DOT text, indented with the indent
// manipulators.
// NOTE(review): a line between the first statement and the rest is not
// visible; presumably the cluster output is guarded by a "has nested
// states" condition - confirm. The trailing return is not visible either.
94 ostream& operator<<(ostream& os, const State& s)
// The state itself as a node line.
96 os << indent << "" << s.name << "\n";
// Dashed edge from the state to its initial inner state.
98 os << indent << s.name << " -> " << s.initialInnerState << " [style = dashed]\n";
// Open a cluster subgraph for the nested states, one level deeper.
99 os << indent << "subgraph cluster_" << s.name << " {\n" << indent_inc;
100 os << indent << "label = \"" << s.name << "\"\n";
// The initial inner state is drawn with a double outline.
101 os << indent << s.initialInnerState << " [peripheries=2]\n";
// NOTE(review): casting to `Context` by value copies the whole map before
// printing; a cast to `const Context&` would select the same overload
// without the copy.
102 os << static_cast<Context>(s);
// Close the cluster at the outer indentation level.
103 os << indent_dec << indent << "}\n";
// Writes a context by visiting each of its entries in map order.
// NOTE(review): the loop body and the return are not visible in this view;
// presumably each contained State is streamed - confirm.
109 ostream& operator<<(ostream& os, const Context& c)
111 for (Context::const_iterator i = c.begin(), e = c.end(); i != e; i++) {
// NOTE(review): the enclosing class header is not visible in this view.
// Immutable description of one transition: source state, destination state
// and the event that labels it.
121 const string src, dst, event;
// Stores the three names as given.
122 Transition(string src, string dst, string event) : src(src), dst(dst), event(event) {}
125 ostream& operator<<(ostream& os, const Transition& t)
127 os << indent << t.src << " -> " << t.dst << " [label = \"" << t.event << "\"]\n";
// Machine: a state machine; as a Context it holds its top-level states.
// NOTE(review): the `name` member initialised by the constructor, the
// access specifiers and the closing "};" are not visible in this view.
132 class Machine : public Context
// Name of the machine's initial state; stays empty until setInitialState()
// is called.
135 string initial_state;
// Creates a machine called `name`.
138 explicit Machine(string name) : name(name) {}
// Records which state is the initial one.
140 void setInitialState(string name) { initial_state = name; }
// The printer reads initial_state, hence the friendship.
142 friend ostream& operator<<(ostream& os, const Machine& m);
145 ostream& operator<<(ostream& os, const Machine& m)
147 os << indent << "subgraph " << m.name << " {\n" << indent_inc;
148 os << indent << m.initial_state << " [peripheries=2]\n";
149 os << static_cast<Context>(m);
150 os << indent_dec << indent << "}\n";
// Model: the whole result of the analysis - a map from machine name to
// Machine, a side context for states seen before their definition, and the
// list of transitions.
// NOTE(review): several lines of this class (braces, access specifiers,
// return statements, loop bodies) are not visible in this view; the
// comments below describe only what the visible lines establish.
155 class Model : public map<string, Machine>
157 Context unknown; // For forward-declared state classes
// NOTE(review): holds raw pointers; no visible code frees them - confirm
// ownership.
159 list< Transition*> transitions;
// Inserts a copy of `m` keyed by m.name; map::insert keeps an existing
// entry with the same key.
161 iterator add(const Machine &m)
163 pair<iterator, bool> ret = insert(value_type(m.name, m));
// Stores `m` in `unknown` under its name, replacing any pointer previously
// stored for that name.
167 void addUnknownState(State *m)
169 unknown[m->name] = m;
// Searches for a context called `name`: among the unknown states, then
// among the machines by key, then recursively inside every machine.
// NOTE(review): the statements executed on a hit and the final return are
// not visible.
173 Context *findContext(const string &name)
175 Context::iterator ci = unknown.find(name);
176 if (ci != unknown.end())
178 iterator i = find(name), e;
181 for (i = begin(), e = end(); i != e; ++i) {
182 Context *c = i->second.findContext(name);
// Looks `name` up in `unknown`.
// NOTE(review): presumably returns null when absent and otherwise erases
// the entry and returns its State - the relevant lines are not visible.
189 State *removeFromUnknownContexts(const string &name)
191 Context::iterator ci = unknown.find(name);
192 if (ci == unknown.end())
// Writes the model to file `fn` as a Graphviz "digraph statecharts".
// One loop runs over the machines and one over the transitions; the loop
// bodies are not visible in this view.
198 void write_as_dot_file(string fn)
200 ofstream f(fn.c_str());
201 f << "digraph statecharts {\n" << indent_inc;
202 for (iterator i = begin(), e = end(); i != e; i++)
204 for (list<Transition*>::iterator t = transitions.begin(), e = transitions.end(); t != e; ++t)
206 f << indent_dec << "}\n";
// MyCXXRecordDecl: CXXRecordDecl extended with a lookup for a base class
// identified by its qualified name given as a string.
// NOTE(review): parts of the parameter list of FindBaseClassString and
// several braces/returns are not visible in this view.
212 class MyCXXRecordDecl : public CXXRecordDecl
// Callback handed to lookupInBases(): true when the base named by
// `Specifier` has the qualified name passed through `qualName`.
214 static bool FindBaseClassString(const CXXBaseSpecifier *Specifier,
218 string qn(static_cast<const char*>(qualName));
219 const RecordType *rt = Specifier->getType()->getAs<RecordType>();
// NOTE(review): `rt` is dereferenced here; a null check may sit on the line
// that is not visible above - confirm.
221 TagDecl *canon = rt->getDecl()->getCanonicalDecl();
222 return canon->getQualifiedNameAsString() == qn;
// Tests whether this record has a base class whose qualified name is
// `baseStr`. Paths are recorded only when the caller passes `Base`, which
// then receives the base specifier at the end of the first path found.
226 bool isDerivedFrom(const char *baseStr, CXXBaseSpecifier const **Base = 0) const {
227 CXXBasePaths Paths(/*FindAmbiguities=*/false, /*RecordPaths=*/!!Base, /*DetectVirtual=*/false);
228 Paths.setOrigin(const_cast<MyCXXRecordDecl*>(this));
// The name travels through the callback's untyped user-data argument.
229 if (!lookupInBases(&FindBaseClassString, const_cast<char*>(baseStr), Paths))
// NOTE(review): presumably guarded by "if (Base)" on a line not visible.
232 *Base = Paths.front().back().Base;
// Visitor: walks the AST and fills the model with the state machines,
// states and transitions it recognises.
// NOTE(review): the declarations of the `ASTCtx` and `model` members used
// by the constructor are not visible in this view.
238 class Visitor : public RecursiveASTVisitor<Visitor>
// Engine through which all notes and errors are reported.
242 DiagnosticsEngine &Diags;
// Custom diagnostic IDs, assigned in the constructor.
243 unsigned diag_unhandled_reaction_type, diag_unhandled_reaction_decl,
244 diag_found_state, diag_found_statemachine, diag_no_history;
// Ask RecursiveASTVisitor to traverse template instantiations as well.
247 bool shouldVisitTemplateInstantiations() const { return true; }
249 explicit Visitor(ASTContext *Context, Model::Model &model, DiagnosticsEngine &Diags)
250 : ASTCtx(Context), model(model), Diags(Diags)
252 diag_found_statemachine =
253 Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found statemachine '%0'");
255 Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found state '%0'");
256 diag_unhandled_reaction_type =
257 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction type '%0'");
258 diag_unhandled_reaction_decl =
259 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction decl '%0'");
260 diag_unhandled_reaction_decl =
261 Diags.getCustomDiagID(DiagnosticsEngine::Error, "History is not yet supported");
264 DiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID) { return Diags.Report(Loc, DiagID); }
// Interprets type `T`, found in the reactions of `SrcState`, and appends
// the matching transition(s) to model.transitions. `Loc` is used for
// diagnostics only.
// Handled templates: boost::statechart::transition, custom_reaction and
// deferral, plus boost::mpl::list, whose arguments are handled recursively.
// Anything else is reported as an unhandled reaction type.
// NOTE(review): the `else` lines introducing the two diagnostics are not
// visible in this view. The results of getAsTemplateDecl() and
// getAsCXXRecordDecl() are dereferenced without a visible null check.
266 void HandleReaction(const Type *T, const SourceLocation Loc, CXXRecordDecl *SrcState)
268 // TODO: Improve Loc tracking
// Look through an elaborated type to the type it names.
269 if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(T))
270 HandleReaction(ET->getNamedType().getTypePtr(), Loc, SrcState);
271 else if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(T)) {
272 string name = TST->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
// transition<Event, Destination>: a real edge SrcState -> Destination.
273 if (name == "boost::statechart::transition") {
274 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
275 const Type *DstStateType = TST->getArg(1).getAsType().getTypePtr();
276 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
277 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
// NOTE(review): raw `new`; the pointer is kept in model.transitions.
279 Model::Transition *T = new Model::Transition(SrcState->getName(), DstState->getName(), Event->getName());
280 model.transitions.push_back(T);
// custom_reaction<Event>: the destination is a quoted placeholder node.
281 } else if (name == "boost::statechart::custom_reaction") {
282 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
283 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
285 Model::Transition *T = new Model::Transition(SrcState->getName(), "\"??? custom\"", Event->getName());
286 model.transitions.push_back(T);
// deferral<Event>: likewise drawn towards a quoted placeholder node.
287 } else if (name == "boost::statechart::deferral") {
288 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
289 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
291 Model::Transition *T = new Model::Transition(SrcState->getName(), "\"??? deferral\"", Event->getName());
292 model.transitions.push_back(T);
// mpl::list<...>: every template argument is a reaction of its own.
293 } else if (name == "boost::mpl::list") {
294 for (TemplateSpecializationType::iterator Arg = TST->begin(), End = TST->end(); Arg != End; ++Arg)
295 HandleReaction(Arg->getAsType().getTypePtr(), Loc, SrcState);
// Template with an unrecognised name: report the name.
297 Diag(Loc, diag_unhandled_reaction_type) << name;
// Neither elaborated nor a template specialization: report the type class.
299 Diag(Loc, diag_unhandled_reaction_type) << T->getTypeClassName();
// Handles one declaration found under the name "reactions" in `SrcState`.
// A typedef is resolved to its underlying type and passed to the Type
// overload; any other kind of declaration is reported as unhandled.
// NOTE(review): the `else` line before the diagnostic is not visible.
302 void HandleReaction(const NamedDecl *Decl, CXXRecordDecl *SrcState)
304 if (const TypedefDecl *r = dyn_cast<TypedefDecl>(Decl))
305 HandleReaction(r->getCanonicalDecl()->getUnderlyingType().getTypePtr(),
306 r->getLocStart(), SrcState);
308 Diag(Decl->getLocation(), diag_unhandled_reaction_decl) << Decl->getDeclKindName();
// Returns the record declaration of template argument number `ArgNum`
// (zero-based) of template specialization type `T`, looking through an
// elaborated type first.
// NOTE(review): the fall-through return is not visible in this view;
// presumably a null pointer when `T` is no template specialization or has
// too few arguments - confirm.
311 CXXRecordDecl *getTemplateArgDecl(const Type *T, unsigned ArgNum)
313 if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(T))
314 return getTemplateArgDecl(ET->getNamedType().getTypePtr(), ArgNum);
315 else if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(T)) {
// Only index the argument list when it is long enough.
316 if (TST->getNumArgs() >= ArgNum+1)
317 return TST->getArg(ArgNum).getAsType()->getAsCXXRecordDecl();
// Visitor callback for C++ record declarations. Classifies the record by
// its Boost.Statechart base class (simple_state, state_machine or event)
// and updates the model accordingly.
// NOTE(review): braces, the declaration of `state`, the statements that
// attach the state or machine to the model, and the return statements are
// not visible in this view.
323 bool VisitCXXRecordDecl(CXXRecordDecl *Declaration)
// Declarations without a complete definition are tested for here; the
// statement taken in that case is not visible.
325 if (!Declaration->isCompleteDefinition())
// NOTE(review): the object was not created as a MyCXXRecordDecl; the cast
// only serves to reach the extra isDerivedFrom() overload - confirm this
// is acceptable for the toolchain in use.
328 MyCXXRecordDecl *RecordDecl = static_cast<MyCXXRecordDecl*>(Declaration);
329 const CXXBaseSpecifier *Base;
// Case 1: a state.
331 if (RecordDecl->isDerivedFrom("boost::statechart::simple_state", &Base))
333 string name(RecordDecl->getName()); //getQualifiedNameAsString());
334 Diag(RecordDecl->getLocStart(), diag_found_state) << name;
337 // Either we saw a reference to forward declared state
338 // before, or we create a new state.
339 if (!(state = model.removeFromUnknownContexts(name)))
340 state = new Model::State(name);
// Template argument 1 of simple_state names the enclosing context.
// NOTE(review): `Context` is dereferenced without a visible null check.
342 CXXRecordDecl *Context = getTemplateArgDecl(Base->getType().getTypePtr(), 1);
343 Model::Context *c = model.findContext(Context->getName());
// A State for the context is created and parked among the unknown states;
// presumably only when the lookup above failed (condition not visible).
345 Model::State *s = new Model::State(Context->getName());
346 model.addUnknownState(s);
// Template argument 2, when present, names the initial inner state.
351 if (CXXRecordDecl *InnerInitialState = getTemplateArgDecl(Base->getType().getTypePtr(), 2))
352 state->setInitialInnerState(InnerInitialState->getName());
354 // if (CXXRecordDecl *History = getTemplateArgDecl(Base->getType().getTypePtr(), 3))
355 // Diag(History->getLocStart(), diag_no_history);
// Every member called "reactions" is interpreted as a reaction.
357 IdentifierInfo& II = ASTCtx->Idents.get("reactions");
358 // TODO: Lookup for reactions even in base classes - probably by using Sema::LookupQualifiedName()
359 for (DeclContext::lookup_result Reactions = RecordDecl->lookup(DeclarationName(&II));
360 Reactions.first != Reactions.second; ++Reactions.first)
361 HandleReaction(*Reactions.first, RecordDecl);
// Case 2: a state machine.
363 else if (RecordDecl->isDerivedFrom("boost::statechart::state_machine", &Base))
365 Model::Machine m(RecordDecl->getName());
366 Diag(RecordDecl->getLocStart(), diag_found_statemachine) << m.name;
// Template argument 1 of state_machine names the initial state.
368 if (CXXRecordDecl *InitialState = getTemplateArgDecl(Base->getType().getTypePtr(), 1))
369 m.setInitialState(InitialState->getName());
// Case 3: an event - recognised, but no visible statement records it.
372 else if (RecordDecl->isDerivedFrom("boost::statechart::event"))
374 //sc.events.push_back(RecordDecl->getNameAsString());
// AST consumer that runs the Visitor over a translation unit and then
// writes the collected model to a DOT file.
// NOTE(review): the declarations of the `model`, `visitor` and
// `destFileName` members are not visible in this view. `visitor` is
// initialised with a reference to `model`, so `model` has to be declared
// before it - confirm.
381 class VisualizeStatechartConsumer : public clang::ASTConsumer
// `destFileName` is the path of the DOT file to write.
387 explicit VisualizeStatechartConsumer(ASTContext *Context, std::string destFileName,
388 DiagnosticsEngine &D)
389 : visitor(Context, model, D), destFileName(destFileName) {}
// Traverses the whole translation unit, then dumps the model.
391 virtual void HandleTranslationUnit(clang::ASTContext &Context) {
392 visitor.TraverseDecl(Context.getTranslationUnitDecl());
393 model.write_as_dot_file(destFileName);
// Plugin action: creates the consumer and handles the plugin's
// command-line arguments.
// NOTE(review): braces, return statements and some statements are not
// visible in this view.
397 class VisualizeStatechartAction : public PluginASTAction
// Derives the output name from the current input file by cutting it at the
// last '.'. When there is no '.', find_last_of returns npos and substr
// keeps the whole name.
// NOTE(review): a line between computing `dest` and the return is not
// visible; presumably it appends the output extension - confirm.
400 ASTConsumer *CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) {
401 size_t dot = getCurrentFile().find_last_of('.');
402 std::string dest = getCurrentFile().substr(0, dot);
404 return new VisualizeStatechartConsumer(&CI.getASTContext(), dest, CI.getDiagnostics());
// Echoes every plugin argument to stderr. "-an-error" builds an error
// diagnostic; "help" as first argument prints the help text.
407 bool ParseArgs(const CompilerInstance &CI,
408 const std::vector<std::string>& args) {
409 for (unsigned i = 0, e = args.size(); i != e; ++i) {
410 llvm::errs() << "Visualizer arg = " << args[i] << "\n";
412 // Example error handling.
413 if (args[i] == "-an-error") {
414 DiagnosticsEngine &D = CI.getDiagnostics();
415 unsigned DiagID = D.getCustomDiagID(
416 DiagnosticsEngine::Error, "invalid argument '" + args[i] + "'");
421 if (args.size() && args[0] == "help")
422 PrintHelp(llvm::errs());
// Writes the (placeholder) help text to `ros`.
426 void PrintHelp(llvm::raw_ostream& ros) {
427 ros << "Help for Visualize Statechart plugin goes here\n";
// Registers VisualizeStatechartAction in clang's frontend plugin registry
// under the name "visualize-statechart".
432 static FrontendPluginRegistry::Add<VisualizeStatechartAction> X("visualize-statechart", "visualize statechart");