almithal / mindmap.py
lordvader31's picture
major update from bitbucket
79b94f8
raw
history blame
1.69 kB
import os
import openai
import json
import graphviz
import streamlit as st
class MindMap:
    """Build a relationship mind map from text chunks using an OpenAI completion model.

    Relationships for each text chunk are requested through the prompt template
    stored at ``./prompts/mindmap.prompt`` and rendered as a Graphviz digraph
    inside a Streamlit app.
    """

    def __init__(self):
        # Configures the module-level openai client (shared by every caller of
        # the openai module); the key is None when OPENAI_API_KEY is unset.
        openai.api_key = os.getenv("OPENAI_API_KEY")

    @staticmethod
    def _build_prompt(template: str, text_chunk: str) -> str:
        """Return *template* with every ``$prompt`` placeholder replaced by *text_chunk*."""
        return template.replace("$prompt", text_chunk)

    @staticmethod
    def _parse_relations(completion_text: str) -> list:
        """Decode one model completion into the relations it contains.

        The completion text is embedded as the value of a ``"relations"`` key
        before decoding, so it must be a single JSON value -- presumably an
        array of relation triples; confirm against the prompt template.

        Raises:
            json.JSONDecodeError: if the wrapped completion is not valid JSON.
        """
        data = json.loads('{"relations":' + completion_text + '}')
        return data["relations"]

    def get_connections(self, text_chunks_libs: dict) -> list:
        """Ask the model for the relationships in every text chunk and collect them.

        Args:
            text_chunks_libs: mapping whose values are iterables of text chunks;
                the keys are not used here.

        Returns:
            A flat list of all relations returned for all chunks, in iteration order.

        Raises:
            OSError: if the prompt template file cannot be read.
            json.JSONDecodeError: if a completion is not valid JSON.
        """
        # Read the template once; the context manager closes the file even if
        # reading fails.
        with open("./prompts/mindmap.prompt", encoding="utf-8") as prompt_file:
            template = prompt_file.read()

        final_connections = []
        for text_chunks in text_chunks_libs.values():
            for text_chunk in text_chunks:
                # Substitute into the untouched template every time: overwriting
                # the template would consume the "$prompt" placeholder on the
                # first chunk, and every later chunk would re-send that text.
                prompt = self._build_prompt(template, text_chunk)
                # Pre-1.0 openai SDK Completions interface.
                response = openai.Completion.create(
                    engine="text-davinci-003",
                    prompt=prompt,
                    temperature=0.5,
                    max_tokens=2048,
                    top_p=1,
                    frequency_penalty=0.0,
                    presence_penalty=0.0,
                )
                relations = self._parse_relations(response.choices[0].text)
                print(relations)
                final_connections.extend(relations)
        print(final_connections)
        return final_connections

    def generate_graph(self, text_chunks_libs: dict):
        """Render the collected relationships as a directed graph in Streamlit.

        Each connection contributes one edge from ``connection[0]`` to
        ``connection[2]``; ``connection[1]`` is presumably the relation label
        (unused here) -- confirm against the prompt template.

        Args:
            text_chunks_libs: same mapping accepted by :meth:`get_connections`.
        """
        graph = graphviz.Digraph()
        for connection in self.get_connections(text_chunks_libs):
            graph.edge(connection[0], connection[2])
        st.graphviz_chart(graph)