2 ////////////////////////////////////////////////////////////////////////////////////////
4 // This file is part of Boost Statechart Viewer.
6 // Boost Statechart Viewer is free software: you can redistribute it and/or modify
7 // it under the terms of the GNU General Public License as published by
8 // the Free Software Foundation, either version 3 of the License, or
9 // (at your option) any later version.
11 // Boost Statechart Viewer is distributed in the hope that it will be useful,
12 // but WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 // GNU General Public License for more details.
16 // You should have received a copy of the GNU General Public License
17 // along with Boost Statechart Viewer. If not, see <http://www.gnu.org/licenses/>.
19 ////////////////////////////////////////////////////////////////////////////////////////
21 //standard header files
25 #include "llvm/Support/raw_ostream.h"
26 #include "llvm/Support/raw_os_ostream.h"
29 #include "clang/AST/ASTConsumer.h"
30 #include "clang/AST/CXXInheritance.h"
31 #include "clang/AST/RecursiveASTVisitor.h"
32 #include "clang/Frontend/CompilerInstance.h"
33 #include "clang/Frontend/FrontendPluginRegistry.h"
35 using namespace clang;
// Members of the Statechart model (the enclosing class/struct declarations are
// not visible in this listing):
//  - Statechart::Transition is one diagram edge: the source state name, the
//    destination state name and the name of the event that triggers it;
//  - `transitions` collects every edge recorded by Visitor::HandleReaction()
//    and is emitted as DOT edges by write_dot_file().
44 string src, dst, event;
45 Transition(string src, string dst, string event) : src(src), dst(dst), event(event) {}
49 list<Transition> transitions;
// NOTE(review): cReactions is neither read nor written anywhere in the code
// visible in this listing — confirm whether it is still needed.
50 list<string> cReactions; /** list of custom reactions. After all files are traversed this list should be empty. */
// Writes the collected model to the file `fn` as a Graphviz DOT digraph named
// after the state machine. The initial state is drawn with a double outline
// (peripheries=2), every collected state becomes a node, and every transition
// becomes an edge labelled with its event name.
// NOTE(review): identifiers are written unquoted. VisitCXXRecordDecl stores the
// machine's *qualified* name (e.g. "ns::Machine"), which is presumably not a
// valid unquoted DOT ID; quoting the names would be safer — confirm.
// NOTE(review): no check for a failed open of `f` is visible here. If no state
// machine was found, name/name_of_start are presumably empty and the output
// would be malformed.
54 void write_dot_file(string fn)
56 ofstream f(fn.c_str());
57 f << "digraph " << name << " {\n";
// Initial state: taken from the state_machine's second template argument.
58 f << " " << name_of_start << " [peripheries=2]\n";
59 for (string& s : states) {
60 f << " " << s << "\n";
63 for (Transition &t : transitions) {
64 f << t.src << " -> " << t.dst << " [label = \"" << t.event << "\"]\n";
// Thin extension of clang's CXXRecordDecl that tests inheritance by the
// *qualified name* of a base class (e.g. "boost::statechart::simple_state")
// instead of by declaration pointer. The Boost declarations therefore need not
// be looked up first.
// It is never constructed in the visible code: VisitCXXRecordDecl
// static_casts an existing CXXRecordDecl* to MyCXXRecordDecl*.
// NOTE(review): that downcast is formally undefined behaviour, because the
// object is not really a MyCXXRecordDecl. It only works in practice because
// this class appears to add no data members or virtual functions. A free
// function taking a const CXXRecordDecl* would avoid the cast.
72 class MyCXXRecordDecl : public CXXRecordDecl
// lookupInBases() callback. Returns true when the base described by Specifier
// is a record whose canonical declaration's qualified name equals the C string
// passed through the opaque `qualName` user-data pointer.
// NOTE(review): `rt` is dereferenced, but no null check is visible in this
// listing. getAs<RecordType>() yields null for non-record base types such as
// dependent template bases — confirm such bases never reach this callback.
74 static bool FindBaseClassString(const CXXBaseSpecifier *Specifier,
78 string qn(static_cast<const char*>(qualName));
79 const RecordType *rt = Specifier->getType()->getAs<RecordType>();
81 TagDecl *canon = rt->getDecl()->getCanonicalDecl();
82 return canon->getQualifiedNameAsString() == qn;
// Returns whether this record derives, directly or indirectly, from the class
// whose qualified name is baseStr. Ambiguity and virtual-base detection are
// disabled. Paths are recorded only when Base is non-null. The final element of
// the first recorded path supplies *Base; the guard around that assignment is
// not visible in this listing.
86 bool isDerivedFrom(const char *baseStr, CXXBaseSpecifier const **Base = 0) const {
87 CXXBasePaths Paths(/*FindAmbiguities=*/false, /*RecordPaths=*/!!Base, /*DetectVirtual=*/false);
88 Paths.setOrigin(const_cast<MyCXXRecordDecl*>(this));
89 if (!lookupInBases(&FindBaseClassString, const_cast<char*>(baseStr), Paths))
92 *Base = Paths.front().back().Base;
// AST visitor that recognises Boost.Statechart building blocks in every C++
// record of the translation unit: state machines, states, events and the
// `reactions` of states. It records them in the Statechart given to its
// constructor and reports progress and problems as clang diagnostics.
98 class Visitor : public RecursiveASTVisitor<Visitor>
102 DiagnosticsEngine &Diags;
// Custom diagnostic IDs, assigned once in the constructor.
103 unsigned diag_unhandled_reaction_type, diag_unhandled_reaction_decl,
104 diag_found_state, diag_found_statemachine;
// Also visit template instantiations, so records that exist only as
// instantiations of class templates are classified too.
107 bool shouldVisitTemplateInstantiations() const { return true; }
// Keeps Context, sc and Diags for use during traversal. Diags is held by
// reference, so the DiagnosticsEngine must outlive the visitor. Also registers
// the plugin's diagnostics: notes for found state machines and states, and
// errors for reaction types and declarations the visitor cannot interpret.
109 explicit Visitor(ASTContext *Context, Statechart &sc, DiagnosticsEngine &Diags)
110 : Context(Context), sc(sc), Diags(Diags)
112 diag_found_statemachine =
113 Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found statemachine '%0'");
115 Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found state '%0'");
116 diag_unhandled_reaction_type =
117 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction type '%0'");
118 diag_unhandled_reaction_decl =
119 Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction decl '%0'");
// Shorthand for Diags.Report(Loc, DiagID). Diagnostic arguments are streamed
// into the returned builder with operator<<.
122 DiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID) { return Diags.Report(Loc, DiagID); }
// Records the transitions described by the reaction type T of state SrcState:
//  - an ElaboratedType (a type written with a qualifier or keyword) is
//    unwrapped and examined again;
//  - boost::statechart::transition<Event, Destination, ...> appends the edge
//    SrcState -> Destination, labelled with Event's name;
//  - boost::mpl::list<...> recurses into each of its template arguments;
//  - diag_unhandled_reaction_type (naming T's type class) is reported at Loc
//    for types the code cannot interpret. Part of the branch structure is not
//    visible in this listing.
// NOTE(review): only the two template names above are matched. Other reaction
// templates (e.g. custom_reaction, deferral, in_state_reaction) are not.
// NOTE(review): getAsTemplateDecl() and getAsCXXRecordDecl() may return null
// (e.g. for dependent template arguments or non-class types). Their results
// are dereferenced without a visible check.
124 void HandleReaction(const Type *T, const SourceLocation Loc, CXXRecordDecl *SrcState)
126 if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(T))
127 HandleReaction(ET->getNamedType().getTypePtr(), Loc, SrcState);
128 else if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(T)) {
129 string name = TST->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
130 if (name == "boost::statechart::transition") {
// transition<>'s first template argument is the event; the second is the
// destination state.
131 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
132 const Type *DstStateType = TST->getArg(1).getAsType().getTypePtr();
133 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
134 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
136 sc.transitions.push_back(Statechart::Transition(SrcState->getName(), DstState->getName(),
138 } else if (name == "boost::mpl::list") {
139 for (TemplateSpecializationType::iterator Arg = TST->begin(), End = TST->end(); Arg != End; ++Arg)
140 HandleReaction(Arg->getAsType().getTypePtr(), Loc, SrcState);
142 //->getDecl()->getQualifiedNameAsString();
144 Diag(Loc, diag_unhandled_reaction_type) << T->getTypeClassName();
// Handles one declaration found under the name `reactions` inside state
// SrcState. A typedef is resolved to the underlying type of its canonical
// declaration. That type is passed to the type overload of HandleReaction
// together with the typedef's start location. Any other kind of declaration
// reports diag_unhandled_reaction_decl with the declaration kind name,
// presumably via an `else` that is not visible in this listing.
// NOTE(review): a C++11 alias declaration (`using reactions = ...;`) is a
// TypeAliasDecl, not a TypedefDecl, so it takes the diagnostic path.
// dyn_cast<TypedefNameDecl> would accept both forms.
147 void HandleReaction(const NamedDecl *Decl, CXXRecordDecl *SrcState)
149 if (const TypedefDecl *r = dyn_cast<TypedefDecl>(Decl))
150 HandleReaction(r->getCanonicalDecl()->getUnderlyingType().getTypePtr(),
151 r->getLocStart(), SrcState);
153 Diag(Decl->getLocation(), diag_unhandled_reaction_decl) << Decl->getDeclKindName();
// RecursiveASTVisitor hook invoked for every C++ record. Records that are not
// complete definitions are presumably skipped (the early-return line is not
// visible). Each complete definition is classified by the first matching
// Boost.Statechart base, checked in this order:
//  - boost::statechart::simple_state: a note is issued, the *unqualified* name
//    is appended to sc.states, and every declaration named `reactions` found
//    by lookup in the record itself is passed to HandleReaction();
//  - boost::statechart::state_machine: the *qualified* name becomes sc.name.
//    When the base is written as an ElaboratedType wrapping a template
//    specialization, its second template argument (the initial state) gives
//    sc.name_of_start;
//  - boost::statechart::event: the name is appended to sc.events.
// The return statements are not visible in this listing.
// NOTE(review): states use unqualified names, so same-named states in
// different namespaces would collide in the generated graph.
157 bool VisitCXXRecordDecl(CXXRecordDecl *Declaration)
159 if (!Declaration->isCompleteDefinition())
// See the NOTE(review) on MyCXXRecordDecl about this downcast.
162 MyCXXRecordDecl *RecordDecl = static_cast<MyCXXRecordDecl*>(Declaration);
163 const CXXBaseSpecifier *Base;
165 if (RecordDecl->isDerivedFrom("boost::statechart::simple_state"))
167 string state(RecordDecl->getName()); //getQualifiedNameAsString());
168 Diag(RecordDecl->getLocStart(), diag_found_state) << state;
169 sc.states.push_back(state);
171 IdentifierInfo& II = Context->Idents.get("reactions");
172 // TODO: Lookup for reactions even in base classes - probably by using Sema::LookupQualifiedName()
173 for (DeclContext::lookup_result Reactions = RecordDecl->lookup(DeclarationName(&II));
174 Reactions.first != Reactions.second; ++Reactions.first)
175 HandleReaction(*Reactions.first, RecordDecl);
177 else if (RecordDecl->isDerivedFrom("boost::statechart::state_machine", &Base))
179 sc.name = RecordDecl->getQualifiedNameAsString();
180 Diag(RecordDecl->getLocStart(), diag_found_statemachine) << sc.name;
// NOTE(review): depending on the clang version, a base written without a
// qualifier (e.g. `state_machine<M, S>` after a using-directive) may be a bare
// TemplateSpecializationType rather than an ElaboratedType. In that case
// name_of_start is never set. getAsCXXRecordDecl() below may also return null.
182 if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(Base->getType())) {
183 if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(ET->getNamedType())) {
184 sc.name_of_start = TST->getArg(1).getAsType()->getAsCXXRecordDecl()->getName();
188 else if (RecordDecl->isDerivedFrom("boost::statechart::event"))
190 sc.events.push_back(RecordDecl->getNameAsString());
// ASTConsumer that owns the Statechart model and the Visitor that fills it.
// Once the whole translation unit has been parsed, it traverses the unit and
// writes the model to destFileName as a DOT file.
197 class VisualizeStatechartConsumer : public clang::ASTConsumer
199 Statechart statechart;
203 explicit VisualizeStatechartConsumer(ASTContext *Context, std::string destFileName,
204 DiagnosticsEngine &D)
205 : visitor(Context, statechart, D), destFileName(destFileName) {}
// Called by clang after the entire translation unit has been parsed.
207 virtual void HandleTranslationUnit(clang::ASTContext &Context) {
208 visitor.TraverseDecl(Context.getTranslationUnitDecl());
209 statechart.write_dot_file(destFileName);
// Clang plugin action that creates the statechart-visualising consumer for the
// file being compiled.
213 class VisualizeStatechartAction : public PluginASTAction
// Returns a newly allocated consumer as a raw pointer. NOTE(review): in the
// clang versions with this signature the frontend takes ownership; newer clang
// expects std::unique_ptr<ASTConsumer> — confirm against the targeted version.
216 ASTConsumer *CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) {
// Output name: the input path up to, but not including, its last '.'. The
// rest of the name is built on a line not visible in this listing (presumably
// a ".dot" suffix).
// NOTE(review): find_last_of('.') also matches a dot in a directory name when
// the file itself has no extension (e.g. "dir.d/main" -> "dir").
217 size_t dot = getCurrentFile().find_last_of('.');
218 std::string dest = getCurrentFile().substr(0, dot);
220 return new VisualizeStatechartConsumer(&CI.getASTContext(), dest, CI.getDiagnostics());
// Plugin argument parsing:
//  - every argument is echoed to stderr;
//  - "-an-error" is reported as an invalid-argument error;
//  - a first argument of "help" prints PrintHelp().
// This looks like placeholder/example code. The return statements are not
// visible in this listing.
223 bool ParseArgs(const CompilerInstance &CI,
224 const std::vector<std::string>& args) {
225 for (unsigned i = 0, e = args.size(); i != e; ++i) {
226 llvm::errs() << "Visualizer arg = " << args[i] << "\n";
228 // Example error handling.
229 if (args[i] == "-an-error") {
230 DiagnosticsEngine &D = CI.getDiagnostics();
231 unsigned DiagID = D.getCustomDiagID(
232 DiagnosticsEngine::Error, "invalid argument '" + args[i] + "'");
237 if (args.size() && args[0] == "help")
238 PrintHelp(llvm::errs());
// Prints the (placeholder) help text to ros.
242 void PrintHelp(llvm::raw_ostream& ros) {
243 ros << "Help for Visualize Statechart plugin goes here\n";
// Registers the action with clang's frontend plugin registry under the name
// "visualize-statechart" (the name used to select it, e.g. with -plugin),
// with the description "visualize statechart".
248 static FrontendPluginRegistry::Add<VisualizeStatechartAction> X("visualize-statechart", "visualize statechart");