rtime.felk.cvut.cz Git - boost-statechart-viewer.git/blob - src/visualizer.cpp
Working with custom reactions. More than one transition can exist for a single custom reaction...
[boost-statechart-viewer.git] / src / visualizer.cpp
1 //standard header files
2 #include <iostream>
3 #include <string>
4 #include <fstream>
5 #include <list>
6
7 //LLVM Header files
8 #include "llvm/Support/raw_ostream.h"
9 #include "llvm/Support/Host.h"
10 #include "llvm/Config/config.h"
11
12 //clang header files
13 #include "clang/Frontend/TextDiagnosticPrinter.h"
14 #include "clang/Lex/HeaderSearch.h"
15 #include "clang/Basic/FileManager.h"
16 #include "clang/Frontend/Utils.h"
17 #include "clang/Basic/TargetInfo.h"
18 #include "clang/Lex/Preprocessor.h"
19 #include "clang/Frontend/CompilerInstance.h"
20 #include "clang/AST/ASTConsumer.h"
21 #include "clang/Sema/Lookup.h"
22 #include "clang/Parse/ParseAST.h"
23 #include "clang/Basic/Version.h"
24 #include "clang/Driver/Driver.h"
25 #include "clang/Driver/Compilation.h"
26
27 //my own header files
28 #include "stringoper.h"
29
30 using namespace clang;
31 using namespace clang::driver;
32 using namespace std;
33
34 class MyDiagnosticClient : public TextDiagnosticPrinter // My diagnostic Client
35 {
36         public:
37         MyDiagnosticClient(llvm::raw_ostream &os, const DiagnosticOptions &diags, bool OwnsOutputStream = false):TextDiagnosticPrinter(os, diags, OwnsOutputStream = false){}
38         virtual void HandleDiagnostic(Diagnostic::Level DiagLevel, const DiagnosticInfo &Info)
39         {
40                 TextDiagnosticPrinter::HandleDiagnostic(DiagLevel, Info); // print diagnostic information
41                 if(DiagLevel > 2) // if error/fatal error stop the program
42                 {               
43                         exit(1);
44                 }       
45         }
46 };
47
48 class FindStates : public ASTConsumer
49 {
50         list<string> transitions;
51         list<string> cReactions;
52         list<string> events;
53         string name_of_machine;
54         string name_of_start;
55         StringDecl sd;  
56         int nbrStates;
57         FullSourceLoc *fsloc;
58         public:
59         list<string> states;
60         
61         virtual void Initialize(ASTContext &ctx)//run after the AST is constructed before the consumer starts to work
62         {       
63                 fsloc = new FullSourceLoc(* new SourceLocation(), ctx.getSourceManager());
64                 name_of_start = "";
65                 name_of_machine = "";
66                 nbrStates = 0;
67         }
68
69         virtual void HandleTopLevelDecl(DeclGroupRef DGR)// traverse all top level declarations
70         {
71                 SourceLocation loc;
72       std::string line, output, event;
73                 llvm::raw_string_ostream x(output);
74                 for (DeclGroupRef::iterator i = DGR.begin(), e = DGR.end(); i != e; ++i) 
75                 {
76                         const Decl *decl = *i;
77                         loc = decl->getLocation();
78                         if(loc.isValid())
79                         {
80                                 //cout<<decl->getKind()<<"ss\n";
81                                 if(decl->getKind()==35)
82                                 {                                       
83                                         method_decl(decl);
84                                 }
85                                 if (const TagDecl *tagDecl = dyn_cast<TagDecl>(decl))
86                                 {
87                                         if(tagDecl->isStruct() || tagDecl->isClass()) //is it a struct or class 
88                                         {
89                                                 struct_class(decl);
90                                         }
91                                 }       
92                                 if(const NamespaceDecl *namespaceDecl = dyn_cast<NamespaceDecl>(decl))
93                                 {
94                                         
95                                         DeclContext *declCont = namespaceDecl->castToDeclContext(namespaceDecl);
96                                         //cout<<namedDecl->getNameAsString()<<"   sss\n";
97                                         recursive_visit(declCont);
98                                 
99                                 }
100                         }
101                         output = "";
102                 }
103         }
104         void recursive_visit(const DeclContext *declCont) //recursively visit all decls hidden inside namespaces
105         {
106       std::string line, output, event;
107                 llvm::raw_string_ostream x(output);
108                 SourceLocation loc;
109                 for (DeclContext::decl_iterator i = declCont->decls_begin(), e = declCont->decls_end(); i != e; ++i)
110                 {
111                         const Decl *decl = *i;
112                         //std::cout<<"a "<<decl->getDeclKindName()<<"\n";
113                         loc = decl->getLocation();
114                         if(loc.isValid())
115                         {       
116                                 if(decl->getKind()==35)
117                                 {
118                                         method_decl(decl);
119                                 }
120                                 else if (const TagDecl *tagDecl = dyn_cast<TagDecl>(decl))
121                                 {
122                                         if(tagDecl->isStruct() || tagDecl->isClass()) //is it a structure or class      
123                                         {
124                                                 struct_class(decl);
125                                         }       
126                                 }
127                                 else if(const NamespaceDecl *namespaceDecl = dyn_cast<NamespaceDecl>(decl))
128                                 {
129                                         DeclContext *declCont = namespaceDecl->castToDeclContext(namespaceDecl);
130                                         //cout<<namedDecl->getNameAsString()<<"  sss\n";                        
131                                         recursive_visit(declCont);
132                                 }
133                         }
134                         output = "";
135                 } 
136         }
137                 
138         void struct_class(const Decl *decl) // works with struct or class decl
139         {
140                 string output, line, ret, trans, event; 
141                 llvm::raw_string_ostream x(output);
142                 decl->print(x);
143                 line = sd.get_line_of_code(x.str());
144                 output = "";
145                 int pos, num;
146                 const TagDecl *tagDecl = dyn_cast<TagDecl>(decl);
147                 const NamedDecl *namedDecl = dyn_cast<NamedDecl>(decl);         
148                 if(sd.is_derived(line))
149                 {
150                         const CXXRecordDecl *cRecDecl = dyn_cast<CXXRecordDecl>(decl);
151                                         
152                         if(sd.find_events(cRecDecl, line))
153                         {
154                                 events.push_back(namedDecl->getNameAsString());
155                                 cout<<"New event: "<<namedDecl->getNameAsString()<<"\n";
156                         }
157                         else if(name_of_machine == "")
158                         {
159                                 ret = sd.find_name_of_machine(cRecDecl, line);
160                                 if(!ret.empty())
161                                 {
162                                         pos = ret.find(",");
163                                         name_of_machine = ret.substr(0,pos);
164                                         name_of_start = ret.substr(pos+1);
165                                         cout<<"Name of the state machine: "<<name_of_machine<<"\n";
166                                         cout<<"Name of the first state: "<<name_of_start<<"\n";
167                                 }
168                         }
169                         else
170                         {
171                                 ret = sd.find_states(cRecDecl, line);   
172                                 if(!ret.empty())
173                                 {                               
174                                         const DeclContext *declCont = tagDecl->castToDeclContext(tagDecl);              
175                                         //states.push_back(namedDecl->getNameAsString());
176                                         std::cout << "New state: " << namedDecl->getNameAsString() << "\n";
177                                         states.push_back(ret);
178                                         output="";
179                                         for (DeclContext::decl_iterator i = declCont->decls_begin(), e = declCont->decls_end(); i != e; ++i) 
180                                         {
181                                                 const Decl *decl = *i;
182                                                 if (decl->getKind()==26) 
183                                                 {
184                                                         decl->print(x);
185                                                         output = x.str();
186                                                         line = sd.clean_spaces(sd.cut_type(output));            
187                                                         ret = sd.find_transitions(namedDecl->getNameAsString(),line);
188                                                         if(!ret.empty()) 
189                                                         {
190                                                                 num = sd.count(ret,';')+1;
191                                                                 for(int i = 0;i<num;i++)
192                                                                 {
193                                                                         pos = ret.find(";");
194                                                                         if(pos == 0)
195                                                                         {
196                                                                                 ret = ret.substr(1);
197                                                                                 pos = ret.find(";");
198                                                                                 if(pos==-1) cReactions.push_back(ret);
199                                                                                 else cReactions.push_back(ret.substr(0,pos));   
200                                                                                 num-=1;
201                                                                         }
202                                                                         else 
203                                                                         {
204                                                                                 if(pos==-1) transitions.push_back(ret);
205                                                                                 else transitions.push_back(ret.substr(0,pos));
206                                                                         }
207                                                                         //cout<<ret<<"\n";
208                                                                         if(i!=num-1) ret = ret.substr(pos+1);
209                                                                 }
210                                                                 output="";
211                                                         }
212                                                 }
213                                                 if(decl->getKind()==35) method_decl(decl);
214                                         }
215                                 }
216                         }
217                 }
218         }
219         void method_decl(const Decl *decl)
220         {
221                 string output, line, event;     
222                 llvm::raw_string_ostream x(output);
223                 if(decl->hasBody())
224                 {
225                         decl->print(x);
226                         line = sd.get_return(x.str());
227                         if(sd.test_model(line,"result"))
228                         {
229                                 const FunctionDecl *fDecl = dyn_cast<FunctionDecl>(decl);
230                                 const ParmVarDecl *pvd = fDecl->getParamDecl(0);
231                                 QualType qt = pvd->getOriginalType();                           
232                                 event = qt.getAsString();
233                                 if(event[event.length()-1]=='&') event = event.substr(0,event.length()-2);
234                                 event = event.substr(event.rfind(" ")+1);
235                                 line = dyn_cast<NamedDecl>(decl)->getQualifiedNameAsString();
236                                 line = sd.cut_namespaces(line.substr(0,line.rfind("::")));
237                                 line.append(",");
238                                 line.append(event);
239                                 find_return_stmt(decl->getBody(),line); 
240                                 for(list<string>::iterator i = cReactions.begin();i!=cReactions.end();i++)
241                                 {
242                                         event = *i;
243                                         if(line.compare(event)==0) 
244                                         {
245                                                 cReactions.erase(i);
246                                                 break;
247                                         }
248                                 }       
249                         }
250                 }
251         }
252         void find_return_stmt(Stmt *statemt,string event)
253         {
254                 if(statemt->getStmtClass() == 99) test_stmt(dyn_cast<CaseStmt>(statemt)->getSubStmt(), event);
255                 else
256                 {
257                         for (Stmt::child_range range = statemt->children(); range; ++range)    
258                         {
259                                 test_stmt(*range, event);
260                         }
261                 }
262         }
263         
264         void test_stmt(Stmt *stmt, string event)
265         {
266                 const SourceManager &sman = fsloc->getManager();
267                 int type;
268                 string line, param;
269                 type = stmt->getStmtClass();
270                 switch(type)
271                 {       
272                         case 8 :                find_return_stmt(dyn_cast<DoStmt>(stmt)->getBody(), event); // do
273                                                         break;
274                         case 86 :       find_return_stmt(dyn_cast<ForStmt>(stmt)->getBody(), event); // for
275                                                         break;
276                         case 88 :   find_return_stmt(dyn_cast<IfStmt>(stmt)->getThen(), event); //if then
277                                                         find_return_stmt(dyn_cast<IfStmt>(stmt)->getElse(), event); //if else
278                                                         break;
279                         case 90 :       find_return_stmt(dyn_cast<LabelStmt>(stmt)->getSubStmt(), event); //label
280                                                         break;
281                         case 98 :       line = sman.getCharacterData(dyn_cast<ReturnStmt>(stmt)->getReturnLoc()); 
282                                                         line = sd.get_line_of_code(line).substr(6);
283                                                         line = line.substr(0,line.find("("));
284                                                         if(sd.test_model(line,"transit"))
285                                                         {
286                                                                 param = sd.get_params(line);
287                                                                 transitions.push_back(event.append(",").append(param));
288                                                         }
289                                                         break;
290                         case 99 :       find_return_stmt(stmt, event);
291                                                         break;
292                         case 101 :      find_return_stmt(dyn_cast<SwitchStmt>(stmt)->getBody(), event); // switch
293                                                         break;
294                         case 102 :      find_return_stmt(dyn_cast<WhileStmt>(stmt)->getBody(), event); // while
295                                                         break;
296                         }
297         }
298
299         void save_to_file(std::string output) // save all to the output file
300         {
301                 nbrStates = states.size();
302                 string state, str, context, ctx;
303                 int pos1, pos2, cnt, subs;
304                 ofstream filestr(output.c_str());
305                 //std::cout<<output<<"\n";
306                 filestr<<"digraph "<< name_of_machine<< " {\n";
307                 context = name_of_machine;
308                 for(list<string>::iterator i = states.begin();i!=states.end();i++) // write all states in the context of the automaton
309                 {
310                         state = *i;
311                         cnt = sd.count(state,',');
312                         if(cnt==1)
313                         {
314                                 pos1 = state.find(",");
315                                 ctx = sd.cut_namespaces(state.substr(pos1+1));
316                                 //std::cout<<name_of_machine.length();                          
317                                 if(ctx.compare(0,context.length(),context)==0)
318                                 {
319                                         filestr<<sd.cut_namespaces(state.substr(0,pos1))<<";\n";
320                                         states.erase(i);
321                                         i--;
322                                 }
323                         }
324                         if(cnt==2)
325                         {
326                                 pos1 = state.find(",");
327                                 pos2 = state.rfind(",");
328                                 ctx = sd.cut_namespaces(state.substr(pos1+1,pos2-pos1-1));
329                                 //std::cout<<ctx<<" "<<context<<"\n";
330                                 if(ctx.compare(0,context.length(),context)==0)
331                                 {                               
332                                         filestr<<sd.cut_namespaces(state.substr(0,pos1))<<";\n";
333                                 }
334                         }
335                 }
336                 filestr<<name_of_start<<" [peripheries=2] ;\n";
337                 subs = 0;
338                 while(!states.empty()) // substates ?
339                 {
340                         state = states.front();
341                         filestr<<"subgraph cluster"<<subs<<" {\n";                      
342                         pos1 = state.find(",");
343                         pos2 = state.rfind(",");
344                         context = sd.cut_namespaces(state.substr(0,pos1));
345                         filestr<<"label=\""<<context<<"\";\n";
346                         filestr<<sd.cut_namespaces(state.substr(pos2+1))<<" [peripheries=2] ;\n";       
347                         states.pop_front();     
348                         //std::cout<<states.size();     
349                         for(list<string>::iterator i = states.begin();i!=states.end();i++)
350                         {
351                                 state = *i;
352                                 cnt = sd.count(state,',');
353                                 //std::cout<<state<<" \n";
354                                 if(cnt==1)
355                                 {
356                                         pos1 = state.find(",");
357                                         ctx = sd.cut_namespaces(state.substr(pos1+1));
358                                         if(ctx.compare(0,context.length(),context)==0)
359                                         {
360                                                 filestr<<sd.cut_namespaces(state.substr(0,pos1))<<";\n";
361                                                 states.erase(i);
362                                                 i--;
363                                         }
364                                 }
365                                 if(cnt==2)
366                                 {
367                                         pos1 = state.find(",");
368                                         pos2 = state.rfind(",");
369                                         ctx = sd.cut_namespaces(state.substr(pos1+1,pos2-pos1-1));
370                                         if(ctx.compare(0,context.length(),context)==0) filestr<<sd.cut_namespaces(state.substr(0,pos1))<<";\n";
371                                 }
372                         }
373                         filestr<<"}\n";
374                         subs+=1;        
375                 }               
376                 for(list<string>::iterator i = transitions.begin();i!=transitions.end();i++) // write all transitions
377                 {
378                         state = *i;
379                         pos1 = state.find(",");
380                         filestr<<sd.cut_namespaces(state.substr(0,pos1))<<"->";
381                         pos2 = state.rfind(",");
382                         filestr<<sd.cut_namespaces(state.substr(pos2+1));
383                         filestr<<"[label=\""<<sd.cut_namespaces(state.substr(pos1+1,pos2-pos1-1))<<"\"];\n";
384                 }               
385                 filestr<<"}";
386                 filestr.close();
387         }
388         void print_stats() // print statistics
389         {
390                 cout<<"\n"<<"Statistics: \n";
391                 cout<<"Number of states: "<<nbrStates<<"\n";
392                 cout<<"Number of events: "<<events.size()<<"\n";
393                 cout<<"Number of transitions: "<<transitions.size()<<"\n";
394                 return;
395         }
396
397 };
398
399 int main(int argc, char **argv)
400
401         string inputFilename = "";
402         string outputFilename = "graph.dot"; // initialize output Filename
403         MyDiagnosticClient *mdc = new MyDiagnosticClient(llvm::errs(), * new DiagnosticOptions());
404         llvm::IntrusiveRefCntPtr<DiagnosticIDs> dis(new DiagnosticIDs());       
405         Diagnostic diag(dis,mdc);
406         FileManager fm( * new FileSystemOptions());
407         SourceManager sm (diag, fm);
408         HeaderSearch *headers = new HeaderSearch(fm);
409         
410         Driver TheDriver(LLVM_PREFIX "/bin", llvm::sys::getHostTriple(), "", false, false, diag);
411         TheDriver.setCheckInputsExist(true);
412         TheDriver.CCCIsCXX = 1; 
413         CompilerInvocation compInv;
414         llvm::SmallVector<const char *, 16> Args(argv, argv + argc);
415         llvm::OwningPtr<Compilation> C(TheDriver.BuildCompilation(Args.size(),
416                                                             Args.data()));
417         const driver::JobList &Jobs = C->getJobs();
418         const driver::Command *Cmd = cast<driver::Command>(*Jobs.begin());
419         const driver::ArgStringList &CCArgs = Cmd->getArguments();
420         for(unsigned i = 0; i<Args.size();i++) // find -o in ArgStringList
421         {       
422                 if(strncmp(Args[i],"-o",2)==0) 
423                 {
424                         if(strlen(Args[i])>2)
425                         {
426                                 string str = Args[i];
427                                 outputFilename = str.substr(2);
428                         }
429                         else outputFilename = Args[i+1];
430                         break;
431                 }
432         }
433                 
434         CompilerInvocation::CreateFromArgs(compInv,
435                                           const_cast<const char **>(CCArgs.data()),
436                                           const_cast<const char **>(CCArgs.data())+CCArgs.size(),
437                                           diag);
438
439         HeaderSearchOptions hsopts = compInv.getHeaderSearchOpts();
440         hsopts.ResourceDir = LLVM_PREFIX "/lib/clang/" CLANG_VERSION_STRING;
441         LangOptions lang = compInv.getLangOpts();
442         CompilerInvocation::setLangDefaults(lang, IK_ObjCXX);
443         TargetInfo *ti = TargetInfo::CreateTargetInfo(diag, compInv.getTargetOpts());
444         ApplyHeaderSearchOptions(*headers, hsopts, lang, ti->getTriple());
445         FrontendOptions f = compInv.getFrontendOpts();
446         inputFilename = f.Inputs[0].second;
447
448         cout<<"Input filename: "<<inputFilename<<"\n"; // print Input filename
449         cout<<"Output filename: "<<outputFilename<<"\n"; // print Output filename
450
451
452         Preprocessor pp(diag, lang, *ti, sm, *headers);
453         pp.getBuiltinInfo().InitializeBuiltins(pp.getIdentifierTable(), lang);
454                 
455         InitializePreprocessor(pp, compInv.getPreprocessorOpts(),hsopts,f);
456         
457         const FileEntry *file = fm.getFile(inputFilename);
458         sm.createMainFileID(file);
459         IdentifierTable tab(lang);
460         Builtin::Context builtins(*ti);
461         FindStates c;
462         ASTContext ctx(lang, sm, *ti, tab, * new SelectorTable(), builtins,0);
463         mdc->BeginSourceFile(lang, &pp);//start using diagnostic
464         ParseAST(pp, &c, ctx, false, false);
465         mdc->EndSourceFile(); //end using diagnostic
466         if(c.states.size()>0) c.save_to_file(outputFilename);
467         else cout<<"No state machine was found\n";
468         c.print_stats();
469         return 0;
470 }