File size: 1,617 Bytes
79b94f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import openai
import json
import graphviz
import streamlit as st

class MindMap:
    """Turn text chunks into a mind-map graph via an OpenAI completion model.

    Each text chunk is substituted into a prompt template, sent to the
    ``text-davinci-003`` completions endpoint, and the reply is parsed as a
    JSON array of relations. The relations are then drawn as edges of a
    Graphviz digraph rendered with Streamlit.
    """

    # Path to the prompt template; every "$prompt" token in it is replaced by
    # the text chunk being analysed. Overridable per class or per instance.
    PROMPT_PATH = "./prompts/mindmap.prompt"

    def __init__(self):
        # Sets the key on the openai module itself, so it applies to every
        # openai call in this process, not just to this instance.
        openai.api_key = os.getenv("OPENAI_API_KEY")

    @staticmethod
    def _fill_prompt(template: str, text_chunk: str) -> str:
        """Return *template* with each "$prompt" placeholder replaced by *text_chunk*."""
        return template.replace("$prompt", text_chunk)

    @staticmethod
    def _parse_relations(completion_text: str) -> list:
        """Parse a model completion as a JSON array of relations.

        Raises:
            ValueError: if the text is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError`` subclass) or is not a JSON array.
        """
        relations = json.loads(completion_text)
        if not isinstance(relations, list):
            # Extending with a dict/str would silently add keys/characters.
            raise ValueError(
                "expected a JSON array of relations, got "
                f"{type(relations).__name__}"
            )
        return relations

    def get_connections(self, text_chunks_libs: dict) -> list:
        """Ask the model for the relations in every text chunk and collect them.

        Args:
            text_chunks_libs: mapping whose values are iterables of text
                chunks; the keys themselves are not used.

        Returns:
            All parsed relations from every chunk, concatenated in iteration
            order.

        Raises:
            OSError: if the prompt template at ``PROMPT_PATH`` cannot be read.
            ValueError: if a completion is not a JSON array.
        """
        with open(self.PROMPT_PATH, encoding="utf-8") as prompt_file:
            template = prompt_file.read()

        final_connections = []
        for text_chunks in text_chunks_libs.values():
            for text_chunk in text_chunks:
                # Always fill from the pristine template: overwriting it would
                # consume the "$prompt" placeholder on the first chunk and make
                # every later request resend the first chunk's text.
                prompt = self._fill_prompt(template, text_chunk)
                response = openai.Completion.create(
                    engine="text-davinci-003",
                    prompt=prompt,
                    temperature=0.5,
                    max_tokens=2048,
                    top_p=1,
                    frequency_penalty=0.0,
                    presence_penalty=0.0,
                )
                final_connections.extend(
                    self._parse_relations(response.choices[0].text)
                )
        return final_connections

    def generate_graph(self, text_chunks_libs: dict):
        """Build the mind-map digraph and render it in the Streamlit app.

        Each relation contributes one edge from its element 0 to its element 2
        (element 1 is not used). Presumably relations are
        [subject, relation, object] triples — TODO confirm against the prompt.

        Args:
            text_chunks_libs: same mapping accepted by ``get_connections``.

        Returns:
            The ``graphviz.Digraph`` that was rendered (callers that ignored
            the previous ``None`` return are unaffected).
        """
        graph = graphviz.Digraph()
        for connection in self.get_connections(text_chunks_libs):
            graph.edge(connection[0], connection[2])
        st.graphviz_chart(graph)
        return graph