]> rtime.felk.cvut.cz Git - boost-statechart-viewer.git/blob - src/visualizer.cpp
4d19a3fe399d33c180168593cb16515bef754a02
[boost-statechart-viewer.git] / src / visualizer.cpp
1 /** @file */
2 ////////////////////////////////////////////////////////////////////////////////////////
3 //
4 //    This file is part of Boost Statechart Viewer.
5 //
6 //    Boost Statechart Viewer is free software: you can redistribute it and/or modify
7 //    it under the terms of the GNU General Public License as published by
8 //    the Free Software Foundation, either version 3 of the License, or
9 //    (at your option) any later version.
10 //
11 //    Boost Statechart Viewer is distributed in the hope that it will be useful,
12 //    but WITHOUT ANY WARRANTY; without even the implied warranty of
13 //    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 //    GNU General Public License for more details.
15 //
16 //    You should have received a copy of the GNU General Public License
17 //    along with Boost Statechart Viewer.  If not, see <http://www.gnu.org/licenses/>.
18 //
19 ////////////////////////////////////////////////////////////////////////////////////////
20
21 //standard header files
22 #include <fstream>
23
24 //LLVM Header files
25 #include "llvm/Support/raw_ostream.h"
26 #include "llvm/Support/raw_os_ostream.h"
27
28 //clang header files
29 #include "clang/AST/ASTConsumer.h"
30 #include "clang/AST/CXXInheritance.h"
31 #include "clang/AST/RecursiveASTVisitor.h"
32 #include "clang/Frontend/CompilerInstance.h"
33 #include "clang/Frontend/FrontendPluginRegistry.h"
34
35 using namespace clang;
36 using namespace std;
37
/// Everything collected about one Boost.Statechart state machine while the
/// AST is traversed; write_dot_file() turns it into a Graphviz DOT graph.
class Statechart
{
public:
    /// A transition from state @c src to state @c dst triggered by @c event.
    class Transition
    {
    public:
        std::string src, dst, event;
        Transition(std::string src, std::string dst, std::string event)
            : src(src), dst(dst), event(event) {}
    };
    std::string name;                  ///< Name of the state machine; used as the graph name.
    std::string name_of_start;         ///< Initial state (placeholder; not written to the graph).
    std::list<Transition> transitions; ///< One graph edge per entry.
    std::list<std::string> cReactions; ///< List of custom reactions. After all files are traversed this list should be empty.
    std::list<std::string> events;     ///< Names of the event classes found.
    std::list<std::string> states;     ///< One graph node per entry.

    /// Returns @p id enclosed in double quotes, escaping embedded '"' as '\"'
    /// (the only escape DOT defines inside quoted IDs). Quoting is required
    /// because names such as "ns::Machine" are not valid unquoted DOT IDs.
    static std::string dot_quote(const std::string &id)
    {
        std::string quoted(1, '"');
        for (char c : id) {
            if (c == '"')
                quoted += '\\';
            quoted += c;
        }
        quoted += '"';
        return quoted;
    }

    /// Writes the graph in DOT syntax to @p out: one node per entry of
    /// @c states and one edge, labelled with its event, per transition.
    /// An empty @c name produces an anonymous graph.
    void write_dot(std::ostream &out) const
    {
        out << "digraph ";
        if (!name.empty())
            out << dot_quote(name) << ' ';
        out << "{\n";
        for (const std::string &s : states)
            out << "  " << dot_quote(s) << "\n";
        for (const Transition &t : transitions)
            out << "  " << dot_quote(t.src) << " -> " << dot_quote(t.dst)
                << " [label = " << dot_quote(t.event) << "]\n";
        out << "}\n";
    }

    /// Writes the graph (see write_dot()) to the file @p fn, replacing any
    /// previous contents.
    /// @return false if the file could not be opened or fully written.
    bool write_dot_file(const std::string &fn) const
    {
        std::ofstream f(fn.c_str());
        if (!f)
            return false;
        write_dot(f);
        f.close();
        return !f.fail();
    }
};
69
70
71 class MyCXXRecordDecl : public CXXRecordDecl
72 {
73     static bool FindBaseClassString(const CXXBaseSpecifier *Specifier,
74                                     CXXBasePath &Path,
75                                     void *qualName)
76     {
77         string qn(static_cast<const char*>(qualName));
78         const RecordType *rt = Specifier->getType()->getAs<RecordType>();
79         assert(rt);
80         TagDecl *canon = rt->getDecl()->getCanonicalDecl();
81         return canon->getQualifiedNameAsString() == qn;
82     }
83
84 public:
85     bool isDerivedFrom(const char *baseStr) const {
86         CXXBasePaths Paths(/*FindAmbiguities=*/false, /*RecordPaths=*/false, /*DetectVirtual=*/false);
87         Paths.setOrigin(const_cast<MyCXXRecordDecl*>(this));
88         return lookupInBases(&FindBaseClassString, const_cast<char*>(baseStr), Paths);
89     }
90 };
91
92
93 class Visitor : public RecursiveASTVisitor<Visitor>
94 {
95     ASTContext *Context;
96     Statechart &sc;
97     DiagnosticsEngine &Diags;
98     unsigned diag_unhandled_reaction_type, diag_unhandled_reaction_decl,
99         diag_found_state, diag_found_statemachine;
100
101 public:
102     bool shouldVisitTemplateInstantiations() const { return true; }
103
104     explicit Visitor(ASTContext *Context, Statechart &sc, DiagnosticsEngine &Diags)
105         : Context(Context), sc(sc), Diags(Diags)
106     {
107         diag_found_statemachine =
108             Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found statemachine '%0'");
109         diag_found_state =
110             Diags.getCustomDiagID(DiagnosticsEngine::Note, "Found state '%0'");
111         diag_unhandled_reaction_type =
112             Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction type '%0'");
113         diag_unhandled_reaction_decl =
114             Diags.getCustomDiagID(DiagnosticsEngine::Error, "Unhandled reaction decl '%0'");
115     }
116
117     DiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID) { return Diags.Report(Loc, DiagID); }
118
119     void HandleReaction(const Type *T, const SourceLocation Loc, CXXRecordDecl *SrcState)
120     {
121         if (const ElaboratedType *ET = dyn_cast<ElaboratedType>(T))
122             HandleReaction(ET->getNamedType().getTypePtr(), Loc, SrcState);
123         else if (const TemplateSpecializationType *TST = dyn_cast<TemplateSpecializationType>(T)) {
124             string name = TST->getTemplateName().getAsTemplateDecl()->getQualifiedNameAsString();
125             if (name == "boost::statechart::transition") {
126                 const Type *EventType = TST->getArg(0).getAsType().getTypePtr();
127                 const Type *DstStateType = TST->getArg(1).getAsType().getTypePtr();
128                 CXXRecordDecl *Event = EventType->getAsCXXRecordDecl();
129                 CXXRecordDecl *DstState = DstStateType->getAsCXXRecordDecl();
130
131                 sc.transitions.push_back(Statechart::Transition(SrcState->getName(), DstState->getName(),
132                                                                 Event->getName()));
133             } else if (name == "boost::mpl::list") {
134                 for (TemplateSpecializationType::iterator Arg = TST->begin(), End = TST->end(); Arg != End; ++Arg)
135                     HandleReaction(Arg->getAsType().getTypePtr(), Loc, SrcState);
136             }
137             //->getDecl()->getQualifiedNameAsString();
138         } else
139             Diag(Loc, diag_unhandled_reaction_type) << T->getTypeClassName();
140     }
141
142     void HandleReaction(const NamedDecl *Decl, CXXRecordDecl *SrcState)
143     {
144         if (const TypedefDecl *r = dyn_cast<TypedefDecl>(Decl))
145             HandleReaction(r->getCanonicalDecl()->getUnderlyingType().getTypePtr(),
146                            r->getLocStart(), SrcState);
147         else
148             Diag(Decl->getLocation(), diag_unhandled_reaction_decl) << Decl->getDeclKindName();
149     }
150
151
152     bool VisitCXXRecordDecl(CXXRecordDecl *Declaration)
153     {
154         if (!Declaration->isCompleteDefinition())
155             return true;
156
157         MyCXXRecordDecl *RecordDecl = static_cast<MyCXXRecordDecl*>(Declaration);
158
159         if (RecordDecl->isDerivedFrom("boost::statechart::simple_state"))
160         {
161             string state(RecordDecl->getName()); //getQualifiedNameAsString());
162             Diag(RecordDecl->getLocStart(), diag_found_state) << state;
163             sc.states.push_back(state);
164
165             IdentifierInfo& II = Context->Idents.get("reactions");
166             for (DeclContext::lookup_result Reactions = RecordDecl->lookup(DeclarationName(&II));
167                  Reactions.first != Reactions.second; ++Reactions.first)
168                 HandleReaction(*Reactions.first, RecordDecl);
169         }
170         else if (RecordDecl->isDerivedFrom("boost::statechart::state_machine"))
171         {
172             sc.name = RecordDecl->getQualifiedNameAsString();
173             sc.name_of_start = "tmp"; //RecordDecl->getStateMachineInitialStateAsString()
174             Diag(RecordDecl->getLocStart(), diag_found_statemachine) << sc.name;
175         }
176         else if (RecordDecl->isDerivedFrom("boost::statechart::event"))
177         {
178             sc.events.push_back(RecordDecl->getNameAsString());
179         }
180         return true;
181     }
182 };
183
184
185 class VisualizeStatechartConsumer : public clang::ASTConsumer
186 {
187     Statechart statechart;
188     Visitor visitor;
189     string destFileName;
190 public:
191     explicit VisualizeStatechartConsumer(ASTContext *Context, std::string destFileName,
192                                          DiagnosticsEngine &D)
193         : visitor(Context, statechart, D), destFileName(destFileName) {}
194
195     virtual void HandleTranslationUnit(clang::ASTContext &Context) {
196         visitor.TraverseDecl(Context.getTranslationUnitDecl());
197         statechart.write_dot_file(destFileName);
198     }
199 };
200
201 class VisualizeStatechartAction : public PluginASTAction
202 {
203 protected:
204   ASTConsumer *CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) {
205     size_t dot = getCurrentFile().find_last_of('.');
206     std::string dest = getCurrentFile().substr(0, dot);
207     dest.append(".dot");
208     return new VisualizeStatechartConsumer(&CI.getASTContext(), dest, CI.getDiagnostics());
209   }
210
211   bool ParseArgs(const CompilerInstance &CI,
212                  const std::vector<std::string>& args) {
213     for (unsigned i = 0, e = args.size(); i != e; ++i) {
214       llvm::errs() << "Visualizer arg = " << args[i] << "\n";
215
216       // Example error handling.
217       if (args[i] == "-an-error") {
218         DiagnosticsEngine &D = CI.getDiagnostics();
219         unsigned DiagID = D.getCustomDiagID(
220           DiagnosticsEngine::Error, "invalid argument '" + args[i] + "'");
221         D.Report(DiagID);
222         return false;
223       }
224     }
225     if (args.size() && args[0] == "help")
226       PrintHelp(llvm::errs());
227
228     return true;
229   }
230   void PrintHelp(llvm::raw_ostream& ros) {
231     ros << "Help for Visualize Statechart plugin goes here\n";
232   }
233
234 };
235
236 static FrontendPluginRegistry::Add<VisualizeStatechartAction> X("visualize-statechart", "visualize statechart");
237
238 // Local Variables:
239 // c-basic-offset: 4
240 // End: