"""CodeImportsAnalyzer uses the ast module from Python's standard library
to get what modules are imported in given python files, then uses networkx to generate imports graph
"""
import ast
import asyncio
import os

import aiohttp
import pybase64

from .graph_analyzer import GraphAnalyzer


def construct_fetch_program_text_api_url(api_url):
    # embed credentials in the URL to raise the GitHub API rate limit, e.g.
    # https://api.github.com/... -> https://USER:TOKEN@api.github.com/...
    # https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting
    USER = os.environ.get("USER", "")
    PERSONAL_ACCESS_TOKEN = os.environ.get("PERSONAL_ACCESS_TOKEN", "")

    if USER and PERSONAL_ACCESS_TOKEN:
        protocol, api_url_components = api_url.split("://")
        new_api_url_components = f"{USER}:{PERSONAL_ACCESS_TOKEN}@{api_url_components}"
        return f"{protocol}://{new_api_url_components}"
    else:
        return api_url


async def get_program_text(session, python_file):
    async with session.get(
        construct_fetch_program_text_api_url(python_file["url"]),
        headers={"Accept": "application/vnd.github.v3+json"},
    ) as response:
        data = await response.json()
        if data["encoding"] == "base64":
            return data["content"]
        else:
            # implicitly returns None; analyze() skips such files
            print(
                f"WARNING: {python_file['path']}'s encoding is {data['encoding']}, not base64"
            )


class CodeImportsAnalyzer:
    class _NodeVisitor(ast.NodeVisitor):
        """Records import statements into the most recent entry of `imports`,
        i.e. the file currently being analyzed."""

        def __init__(self, imports):
            self.imports = imports

        def visit_Import(self, node):
            for alias in node.names:
                self.imports[-1]["imports"].append(
                    {"module": None, "name": alias.name, "level": -1}
                )
            self.generic_visit(node)

        def visit_ImportFrom(self, node):
            for alias in node.names:
                self.imports[-1]["imports"].append(
                    {"module": node.module, "name": alias.name, "level": node.level}
                )
            self.generic_visit(node)

    def __init__(self, python_files):
        self.python_imports = []
        self.graph_analyzer = GraphAnalyzer(is_directed=True)
        self.python_files = python_files
        self._node_visitor = CodeImportsAnalyzer._NodeVisitor(self.python_imports)

    async def analyze(self):
        async with aiohttp.ClientSession() as session:
            tasks = [
                asyncio.ensure_future(get_program_text(session, python_file))
                for python_file in self.python_files
            ]

            # asyncio.gather preserves task order, so the results line up
            # with self.python_files
            base64_program_texts = await asyncio.gather(*tasks)
            for python_file, base64_program_text in zip(
                self.python_files, base64_program_texts
            ):
                # append the entry right before visiting its tree so the node
                # visitor, which writes to python_imports[-1], records imports
                # under the correct file
                self.python_imports += [
                    {
                        "file_name": python_file["path"].split("/")[-1],
                        "file_path": python_file["path"],
                        "imports": [],
                    }
                ]
                if base64_program_text is None:
                    continue  # non-base64 content; already warned about
                program = pybase64.b64decode(base64_program_text)
                tree = ast.parse(program)
                self._node_visitor.visit(tree)

    def generate_imports_graph(self):
        # TODO: consider improving the graph generation logic:
        # first build a dictionary-of-lists data structure,
        # then generate the graph from that dictionary

        for python_import in self.python_imports:
            _nodes = python_import["file_path"].split("/")

            # build graph edges from the file path;
            # the node/edge relationship mirrors the folder/file structure
            if len(_nodes) > 1:
                # merge the last two path components into a single node so that
                # files sharing a name in different folders remain distinct
                if len(_nodes) >= 3:
                    _nodes[-2] = _nodes[-2] + "/" + _nodes[-1]
                    del _nodes[-1]
                self.graph_analyzer.add_edges_from_nodes(_nodes)
            else:
                self.graph_analyzer.add_node(_nodes[0])

            # build graph edges from the modules imported by each file
            if python_import["file_name"] != "__init__.py":
                for _import in python_import["imports"]:
                    if _import["module"] is None:
                        # plain `import a.b` statement
                        _import_names = _import["name"].split(".")
                    else:
                        # `from a.b import c` statement
                        _import_names = _import["module"].split(".") + [
                            _import["name"]
                        ]
                    _new_nodes = _import_names + [_nodes[-1]]
                    self.graph_analyzer.add_edges_from_nodes(_new_nodes)

        return self.graph_analyzer.graph

    def report(self):
        from pprint import pprint

        pprint(self.python_imports)
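

# A minimal usage sketch, assuming this module lives in a package that also
# provides graph_analyzer (run it with `python -m <package>.<module>` so the
# relative import resolves). The file entries mimic the GitHub contents API
# shape ({"path": ..., "url": ...}); the repo and paths below are hypothetical.
if __name__ == "__main__":
    python_files = [
        {
            "path": "mypackage/main.py",  # hypothetical file in the repo
            "url": "https://api.github.com/repos/owner/repo/contents/mypackage/main.py",  # hypothetical
        },
    ]
    analyzer = CodeImportsAnalyzer(python_files)
    asyncio.run(analyzer.analyze())
    imports_graph = analyzer.generate_imports_graph()  # a networkx graph
    analyzer.report()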