]> rtime.felk.cvut.cz Git - boost-statechart-viewer.git/blob - src/visualizer.cpp
Merge branch 'master' of rtime.felk.cvut.cz:boost-statechart-viewer
[boost-statechart-viewer.git] / src / visualizer.cpp
1 #include <iostream>
2 #include <string>
3 #include <fstream>
4 #include <list>
5
6
7 #include "llvm/Support/raw_ostream.h"
8 #include "llvm/System/Host.h"
9 #include "llvm/Config/config.h"
10
11 #include "clang/Frontend/DiagnosticOptions.h"
12 #include "clang/Frontend/TextDiagnosticPrinter.h"
13
14 #include "clang/Basic/LangOptions.h"
15 #include "clang/Basic/FileSystemOptions.h"
16
17 #include "clang/Index/TranslationUnit.h"
18 #include "clang/Basic/SourceManager.h"
19 #include "clang/Lex/HeaderSearch.h"
20 #include "clang/Basic/FileManager.h"
21
22 #include "clang/Frontend/HeaderSearchOptions.h"
23 #include "clang/Frontend/Utils.h"
24
25 #include "clang/Basic/TargetOptions.h"
26 #include "clang/Basic/TargetInfo.h"
27
28 #include "clang/Lex/Preprocessor.h"
29 #include "clang/Frontend/PreprocessorOptions.h"
30 #include "clang/Frontend/FrontendOptions.h"
31
32 #include "clang/Frontend/CompilerInvocation.h"
33
34 #include "clang/Basic/IdentifierTable.h"
35 #include "clang/Basic/Builtins.h"
36
37 #include "clang/AST/ASTContext.h"
38 #include "clang/AST/ASTConsumer.h"
39 #include "clang/Sema/Sema.h"
40 #include "clang/AST/DeclBase.h"
41 #include "clang/AST/Type.h"
42 #include "clang/AST/Decl.h"
43 #include "clang/Sema/Lookup.h"
44 #include "clang/Sema/Ownership.h"
45 #include "clang/AST/DeclGroup.h"
46
47 #include "clang/Parse/Parser.h"
48
49 #include "clang/Parse/ParseAST.h"
50 #include "clang/Basic/Version.h"
51
52 #include "llvm/Support/CommandLine.h"
53
54 //my own header files
55 #include "stringoper.h"
56 #include "commandlineopt.h"
57
58 using namespace clang;
59
60
61 class MyDiagnosticClient : public TextDiagnosticPrinter
62 {
63         public:
64         MyDiagnosticClient(llvm::raw_ostream &os, const DiagnosticOptions &diags, bool OwnsOutputStream = false):TextDiagnosticPrinter(os, diags, OwnsOutputStream = false){}
65         virtual void HandleDiagnostic(Diagnostic::Level DiagLevel, const DiagnosticInfo &Info)
66         {
67                 TextDiagnosticPrinter::HandleDiagnostic(DiagLevel, Info);
68                 if(DiagLevel == 3) exit(1);
69         }
70 };
71
72 class FindStates : public ASTConsumer
73 {
74         std::list<string> transitions;
75         std::list<string> states;
76         std::string name_of_machine;
77         std::string name_of_start;
78         FullSourceLoc *fSloc;
79         public:
80
81         virtual void Initialize(ASTContext &ctx)//run after the AST is constructed
82         {       
83                 SourceLocation loc;
84                 name_of_start = "";
85                 name_of_machine = "";
86                 SourceManager &sman = ctx.getSourceManager();
87                 fSloc = new FullSourceLoc(loc, sman);
88         }
89
90         virtual void HandleTopLevelDecl(DeclGroupRef DGR)// traverse all top level declarations
91         {
92                 const SourceManager &sman = fSloc->getManager();
93                 SourceLocation loc;
94                 std::string line;
95                 std::string super_class, output;
96                 llvm::raw_string_ostream x(output);
97                 for (DeclGroupRef::iterator i = DGR.begin(), e = DGR.end(); i != e; ++i) 
98                 {
99                         const Decl *decl = *i;
100                         loc = decl->getLocation();
101                         if(loc.isValid())
102                         {
103                                 const NamedDecl *namedDecl = dyn_cast<NamedDecl>(decl);
104                                 //std::cout<<decl->getDeclKindName()<<"\n";
105                                 if (const TagDecl *tagDecl = dyn_cast<TagDecl>(decl))
106                                 {
107                                         if(tagDecl->isStruct() || tagDecl->isClass()) //is it a structure or class      
108                                         {
109                                                 const CXXRecordDecl *cRecDecl = dyn_cast<CXXRecordDecl>(decl);
110                                                 decl->print(x);
111                                                 //decl->dump();                                                 
112                                                 line = cut_commentary(clean_spaces(get_line_of_code(x.str())));
113                                                 output = "";
114                                                 if(is_derived(line))
115                                                 {
116                                                         if(name_of_machine == "")
117                                                         {
118                                                                 find_name_of_machine(cRecDecl, line);
119                                                         }
120                                                         else
121                                                         {
122                                                                 if(find_states(cRecDecl, line))
123                                                                 {                               
124                                                                         const DeclContext *declCont = tagDecl->castToDeclContext(tagDecl);                                      
125                                                                         std::cout << "New state: " << namedDecl->getNameAsString() << "\n";
126                                                                         find_transitions(namedDecl->getNameAsString(), declCont);
127                                                                 }
128                                                         }
129                                                 }
130                                         }
131                                 }       
132                                 if(const NamespaceDecl *namespaceDecl = dyn_cast<NamespaceDecl>(decl))
133                                 {
134                                         DeclContext *declCont = namespaceDecl->castToDeclContext(namespaceDecl);
135                                         //declCont->dumpDeclContext();                          
136                                         recursive_visit(declCont);
137                                 
138                                 }
139                         }
140                 }
141         }
142         void recursive_visit(const DeclContext *declCont) //recursively visit all decls hidden inside namespaces
143         {
144                 const SourceManager &sman = fSloc->getManager();
145                 std::string line, output;
146                 SourceLocation loc;
147                 llvm::raw_string_ostream x(output);
148                 for (DeclContext::decl_iterator i = declCont->decls_begin(), e = declCont->decls_end(); i != e; ++i)
149                 {
150                         const Decl *decl = *i;
151                         const NamedDecl *namedDecl = dyn_cast<NamedDecl>(decl);
152                         
153                         //std::cout<<"a "<<decl->getDeclKindName()<<"\n";
154                         loc = decl->getLocation();
155                         if(loc.isValid())
156                         {                       
157                                 if (const TagDecl *tagDecl = dyn_cast<TagDecl>(decl))
158                                 {
159                                         if(tagDecl->isStruct() || tagDecl->isClass()) //is it a structure or class      
160                                         {
161                                                 const CXXRecordDecl *cRecDecl = dyn_cast<CXXRecordDecl>(decl);
162                                                 decl->print(x);
163                                                 line = cut_commentary(clean_spaces(get_line_of_code(x.str())));
164                                                 output = "";
165                                                 if(is_derived(line))
166                                                 {
167                                                         if(name_of_machine == "")
168                                                         {
169                                                                 find_name_of_machine(cRecDecl, line);
170                                                         }
171                                                         else
172                                                         {
173                                                                 if(find_states(cRecDecl, line))
174                                                                 {                               
175                                                                         const DeclContext *declCont = tagDecl->castToDeclContext(tagDecl);              
176                                                                         //states.push_back(namedDecl->getNameAsString());
177                                                                         std::cout << "New state: " << namedDecl->getNameAsString() << "\n";
178                                                                         find_transitions(namedDecl->getNameAsString(), declCont);
179                                                                 }
180                                                         }
181                                                 }
182                                         }       
183                                 }
184                                 if(const NamespaceDecl *namespaceDecl = dyn_cast<NamespaceDecl>(decl))
185                                 {
186                                         DeclContext *declCont = namespaceDecl->castToDeclContext(namespaceDecl);
187                                         //declCont->dumpDeclContext();                          
188                                         recursive_visit(declCont);
189                                 }
190                         }
191                 } 
192         }
193         bool find_states(const CXXRecordDecl *cRecDecl, std::string line) // test if the struct/class is the state (must be derived from simple_state)
194         {       
195                 std::string super_class = get_super_class(line), base;          
196                 if(cRecDecl->getNumBases()>1)
197                 {
198                         for(unsigned i = 0; i<cRecDecl->getNumBases();i++ )
199                         {
200                                 if(i!=cRecDecl->getNumBases()-1) base = get_first_base(super_class);
201                                 else base = super_class;
202                                 if(is_state(super_class)) 
203                                 {
204                                         //std::cout<<get_params(super_class);
205                                         states.push_back(get_params(super_class));
206                                         return true;
207                                 }
208                                 else
209                                 {
210                                         super_class = get_next_base(super_class);
211                                 }
212                         }
213                         return false;
214                 }
215                 else
216                 { 
217                         if(is_state(super_class)) 
218                         {
219                                 //std::cout<<get_params(super_class);
220                                 states.push_back(get_params(super_class));
221                                 return true;
222                         }
223                         else return false;
224                 }
225         }
226                 
227         void find_name_of_machine(const CXXRecordDecl *cRecDecl, std::string line) // find name of the state machine and the start state
228         {       
229                 std::string super_class = get_super_class(line), base, params;
230                 
231                 int pos = 0;
232                 if(cRecDecl->getNumBases()>1)
233                 {
234                         for(unsigned i = 0; i<cRecDecl->getNumBases();i++ )
235                         {
236                                 if(i!=cRecDecl->getNumBases()-1) base = get_first_base(super_class);
237                                 else base = super_class;
238                                 if(is_machine(base))
239                                 {
240                                         params = get_params(base);
241                                         pos = params.find(",");
242                                         name_of_machine = params.substr(0,pos);
243                                         name_of_start = params.substr(pos);
244                                         std::cout<<"Name of the state machine: "<<name_of_machine<<"\n";
245                                         std::cout<<"Name of the first state: "<<name_of_start<<"\n";
246                                 }
247                                 else
248                                 {
249                                         super_class = get_next_base(super_class);
250                                 }
251                         }
252                 }
253                 else
254                 { 
255                         if(is_machine(super_class))
256                         {
257                                 //std::cout<<super_class;
258                                 params = get_params(super_class);
259                                 //std::cout<<params;
260                                 pos = params.find(",");
261                                 name_of_machine = cut_namespaces(params.substr(0,pos));
262                                 name_of_start = cut_namespaces(params.substr(pos+1));
263                                 std::cout<<"Name of the state machine:"<<name_of_machine<<"\n";
264                                 std::cout<<"Name of the first state:"<<name_of_start<<"\n";
265                         }
266                 }
267         }
268
269         void find_transitions (const std::string name_of_state,const DeclContext *declCont) // traverse all methods for finding declarations of transitions
270         {       
271                 std::string output, line, dest, params, base;   
272                 llvm::raw_string_ostream x(output);
273                 int num;                
274                 for (DeclContext::decl_iterator i = declCont->decls_begin(), e = declCont->decls_end(); i != e; ++i) 
275                 {
276                         const Decl *decl = *i;
277                         if (const TypedefDecl *typedDecl = dyn_cast<TypedefDecl>(decl)) 
278                         {
279                                         decl->print(x);
280                                         output = x.str();
281                                         line = clean_spaces(cut_typedef(output));
282                                         num = count(output,'<');
283                                         if(num>1)
284                                         {
285                                                 num-=1;
286                                                 if(is_list(line))
287                                                 {
288                                                         line = get_inner_part(line);
289                                                 }
290                                         }
291                                         for(int j = 0;j<num;j++)
292                                         {
293                                                 if(j!=num-1) base = get_first_base(line);                       
294                                                 else base = line;
295                                                 if(is_transition(base))
296                                                 {
297                                                         dest = name_of_state;
298                                                         params = get_params(base);
299                                                         dest.append(",");                                                       
300                                                         dest.append(params);
301                                                         transitions.push_back(dest);
302                                                         line = get_next_base(line);
303                                                 }
304                                                 else
305                                                 {
306                                                         line = get_next_base(line);
307                                                 }
308                                         }
309                                         output = "";
310                         }
311                 }       
312         }
313         
314         void save_to_file(std::string output)
315         {
316                 std::string state, str, context, ctx;
317                 int pos1, pos2, cnt, subs;
318                 std::ofstream filestr(output.c_str());
319                 std::cout<<output<<"\n";
320                 filestr<<"digraph "<< name_of_machine<< " {\n";
321                 context = name_of_machine;
322                 for(list<string>::iterator i = states.begin();i!=states.end();i++) // write all states in the context of the automaton
323                 {
324                         state = *i;
325                         cnt = count(state,',');
326                         if(cnt==1)
327                         {
328                                 pos1 = state.find(",");
329                                 ctx = cut_namespaces(state.substr(pos1+1));
330                                 //std::cout<<name_of_machine.length();                          
331                                 if(ctx.compare(0,context.length(),context)==0)
332                                 {
333                                         filestr<<cut_namespaces(state.substr(0,pos1))<<";\n";
334                                         states.erase(i);
335                                         i--;
336                                 }
337                         }
338                         if(cnt==2)
339                         {
340                                 pos1 = state.find(",");
341                                 pos2 = state.rfind(",");
342                                 ctx = cut_namespaces(state.substr(pos1+1,pos2-pos1-1));
343                                 //std::cout<<ctx<<" "<<context<<"\n";
344                                 if(ctx.compare(0,context.length(),context)==0)
345                                 {                               
346                                         filestr<<cut_namespaces(state.substr(0,pos1))<<";\n";
347                                 }
348                         }
349                 }
350                 filestr<<name_of_start<<" [peripheries=2] ;\n";
351                 subs = 0;
352                 while(!states.empty()) // substates ?
353                 {
354                         state = states.front();
355                         filestr<<"subgraph cluster"<<subs<<" {\n";                      
356                         pos1 = state.find(",");
357                         pos2 = state.rfind(",");
358                         context = cut_namespaces(state.substr(0,pos1));
359                         filestr<<"label=\""<<context<<"\";\n";
360                         filestr<<cut_namespaces(state.substr(pos2+1))<<" [peripheries=2] ;\n";  
361                         states.pop_front();     
362                         //std::cout<<states.size();     
363                         for(list<string>::iterator i = states.begin();i!=states.end();i++)
364                         {
365                                 state = *i;
366                                 cnt = count(state,',');
367                                 //std::cout<<state<<" \n";
368                                 if(cnt==1)
369                                 {
370                                         pos1 = state.find(",");
371                                         ctx = cut_namespaces(state.substr(pos1+1));
372                                         
373                                         //std::cout<<ctx<<" "<<context<<"\n";
374                                         if(ctx.compare(0,context.length(),context)==0)
375                                         {
376                                                 filestr<<cut_namespaces(state.substr(0,pos1))<<";\n";
377                                                 states.erase(i);
378                                                 i--;
379                                         }
380                                 }
381                                 if(cnt==2)
382                                 {
383                                         pos1 = state.find(",");
384                                         pos2 = state.rfind(",");
385                                         ctx = cut_namespaces(state.substr(pos1+1,pos2-pos1-1));
386                                         if(ctx.compare(0,context.length(),context)==0)
387                                         {                               
388                                                 filestr<<cut_namespaces(state.substr(0,pos1))<<";\n";
389                                                 //std::cout<<ctx<<"\n";
390                                         }
391                                 }
392                         }
393                         filestr<<"}\n";
394                         subs+=1;        
395                 }               
396                 for(list<string>::iterator i = transitions.begin();i!=transitions.end();i++) // write all transitions
397                 {
398                         state = *i;
399                         pos1 = state.find(",");
400                         filestr<<cut_namespaces(state.substr(0,pos1))<<"->";
401                         pos2 = state.rfind(",");
402                         filestr<<cut_namespaces(state.substr(pos2+1));
403                         filestr<<"[label=\""<<cut_namespaces(state.substr(pos1+1,pos2-pos1-1))<<"\"];\n";
404                 }               
405                 filestr<<"}";
406                 filestr.close();
407         }
408 };
409
410 int main(int argc, char *argv[])
411 {
412         llvm::cl::ParseCommandLineOptions(argc, argv);  
413         std::cout<<"Input file: "<<inputFilename<<"\n"; 
414         FILE* fileI = fopen(inputFilename.c_str(), "r");
415         if (!fileI)  
416         {
417                 perror(inputFilename.c_str());
418         exit(1);
419         }
420         fclose(fileI);
421         DiagnosticOptions diagnosticOptions;
422         llvm::IntrusiveRefCntPtr<DiagnosticIDs> dis(new DiagnosticIDs());
423         MyDiagnosticClient *mdc = new MyDiagnosticClient(llvm::outs(), diagnosticOptions);
424         Diagnostic diag(dis,mdc);
425         diag.setIgnoreAllWarnings(true);
426         FileSystemOptions fileSysOpt;     
427         LangOptions lang;
428         lang.BCPLComment=1;
429         lang.CPlusPlus=1; 
430         FileManager fm (fileSysOpt);
431
432         SourceManager sm ( diag, fm);
433         HeaderSearch *headers = new HeaderSearch(fm);
434         CompilerInvocation::setLangDefaults(lang, IK_ObjCXX);
435
436         HeaderSearchOptions hsopts;
437         hsopts.ResourceDir=LLVM_PREFIX "/lib/clang/" CLANG_VERSION_STRING;
438         for(unsigned i = 0; i<includeFiles.size();i++)
439         {
440                 hsopts.AddPath(includeFiles[i],
441                                 clang::frontend::Angled,
442                                 false,
443                                 false,
444                                 true);
445         }
446         TargetOptions to;
447         to.Triple = llvm::sys::getHostTriple();
448         TargetInfo *ti = TargetInfo::CreateTargetInfo(diag, to);
449         clang::ApplyHeaderSearchOptions(
450                 *headers,
451                 hsopts,
452                 lang,
453                 ti->getTriple());
454         Preprocessor pp(diag, lang, *ti, sm, *headers);
455         pp.getBuiltinInfo().InitializeBuiltins(pp.getIdentifierTable(),
456                                            pp.getLangOptions());
457         FrontendOptions f;
458         PreprocessorOptions ppio;
459         InitializePreprocessor(pp, ppio,hsopts,f);
460         const FileEntry *file = fm.getFile(inputFilename);
461         sm.createMainFileID(file);
462         IdentifierTable tab(lang);
463         SelectorTable sel;
464         Builtin::Context builtins(*ti);
465         FindStates c;
466         ASTContext ctx(lang, sm, *ti, tab, sel, builtins,0);
467         mdc->BeginSourceFile(lang, &pp);
468         ParseAST(pp, &c, ctx, false, false);
469         mdc->EndSourceFile();
470         c.save_to_file(outputFile);
471         return 0;
472
473 }