import gradio as gr
from py2neo import Graph
from langchain_community.graphs.neo4j_graph import Neo4jGraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WikipediaLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import json
import os
# Set up the connection to the Neo4j database.
# Credentials are read from the environment when available; the hardcoded
# fallbacks keep the original values but should not live in a public Space.
url = os.environ.get("NEO4J_URI", "neo4j+s://ddb8863b.databases.neo4j.io")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "vz6OLij_IrY-cSIgSMhUWxblTUzH8m4bZaBeJGgmtU0")
graph = Graph(url, auth=(username, password))
neo4j_graph = Neo4jGraph(url=url, username=username, password=password)
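# Note: `graph` (py2neo) is the client that actually runs Cypher below;
# `neo4j_graph` (LangChain's wrapper) is created here but not referenced
# elsewhere in this file.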
# Initialize the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
# Function to extract relations from model output
def extract_relations_from_model_output(text):
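    """Parse REBEL's linearized output into {'head', 'type', 'tail'} dicts.

    REBEL decodes each triplet as "<triplet> head <subj> tail <obj> relation",
    which is why the 's' state below accumulates object_ (the tail) and the
    'o' state accumulates the relation text.
    """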
triplets = []
    subject, relation, object_ = '', '', ''
text = text.strip()
current = 'x'
for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
if token == "<triplet>":
current = 't'
if relation != '':
triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
relation = ''
subject = ''
elif token == "<subj>":
current = 's'
if relation != '':
triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
object_ = ''
elif token == "<obj>":
current = 'o'
relation = ''
else:
if current == 't':
subject += ' ' + token
elif current == 's':
object_ += ' ' + token
elif current == 'o':
relation += ' ' + token
if subject != '' and relation != '' and object_ != '':
triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
return triplets
class KB:
    """In-memory collection of relation dicts, deduplicated on (head, type, tail)."""
def __init__(self):
self.relations = []
def are_relations_equal(self, r1, r2):
return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])
def exists_relation(self, r1):
return any(self.are_relations_equal(r1, r2) for r2 in self.relations)
def add_relation(self, r):
if not self.exists_relation(r):
self.relations.append(r)
def print(self):
print("Relations:")
for r in self.relations:
print(f" {r}")
def from_small_text_to_kb(text, verbose=False):
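    """Run REBEL over a single chunk of text and collect its relations in a KB."""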
kb = KB()
model_inputs = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')
    if verbose:
        print(f"Num tokens: {len(model_inputs['input_ids'][0])}")
gen_kwargs = {
"max_length": 216,
"length_penalty": 0,
"num_beams": 3,
"num_return_sequences": 3
}
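    # num_beams=3 with num_return_sequences=3 returns every beam, so each chunk
    # yields three candidate linearizations; parsing all of them trades a little
    # compute for better recall, and KB.add_relation drops the duplicates.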
generated_tokens = model.generate(
**model_inputs,
**gen_kwargs,
)
    # Keep special tokens: the parser relies on the <triplet>/<subj>/<obj> markers.
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
    for sentence_pred in decoded_preds:
        relations = extract_relations_from_model_output(sentence_pred)
        if verbose:
            print(f"Parsed {len(relations)} relations from one candidate sequence")
        for r in relations:
            kb.add_relation(r)
return kb
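# Illustrative usage (the exact output depends on the model):
#   kb = from_small_text_to_kb("Napoleon Bonaparte was born on the island of Corsica.")
#   kb.relations might then contain
#   {'head': 'Napoleon Bonaparte', 'type': 'place of birth', 'tail': 'Corsica'}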
# Function to insert data into Neo4j from Wikipedia query
def insert_data_from_wikipedia(query):
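    """Load Wikipedia pages for `query`, split them into ~512-character chunks,
    extract relations per chunk, and MERGE each one into Neo4j."""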
text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, length_function=len, is_separator_regex=False)
raw_documents = WikipediaLoader(query=query).load_and_split(text_splitter=text_splitter)
if not raw_documents:
print("No documents found for query:", query)
return False
for doc in raw_documents:
kb = from_small_text_to_kb(doc.page_content, verbose=True)
for relation in kb.relations:
head = relation['head']
relationship = relation['type']
tail = relation['tail']
            if head and relationship and tail:
                # Cypher cannot parameterize labels or relationship types, so the
                # extracted strings are interpolated directly; stripping backticks
                # keeps them from breaking out of the quoted identifiers.
                head, relationship, tail = (s.replace("`", "") for s in (head, relationship, tail))
                cypher = f"MERGE (h:`{head}`) MERGE (t:`{tail}`) MERGE (h)-[:`{relationship}`]->(t)"
                print(f"Executing Cypher query: {cypher}")  # Debug print for the generated Cypher
                graph.run(cypher)
            else:
                print(f"Skipping invalid relation: head='{head}', relationship='{relationship}', tail='{tail}'")
return True
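# Note on graph shape: the raw entity strings become node *labels* (not name
# properties), so {'head': 'Paris', 'type': 'capital of', 'tail': 'France'}
# is stored as MERGE (h:`Paris`) MERGE (t:`France`) MERGE (h)-[:`capital of`]->(t).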
# Function to query the database
def query_neo4j(query):
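    """Run a raw Cypher query against Neo4j and return the result (or error) as a JSON string."""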
if not query.strip():
return json.dumps({"error": "Empty Cypher query"}, indent=2) # Handle empty query case
try:
result = graph.run(query).data()
return json.dumps(result, indent=2) # Convert to JSON string
except Exception as e:
return json.dumps({"error": str(e)}, indent=2) # Return error as JSON
# Gradio interface function
def gradio_interface(wiki_query, cypher_query):
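    """Gradio handler: ingest the Wikipedia topic into Neo4j, then run the
    Cypher query and return its JSON result."""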
if not wiki_query.strip():
return json.dumps({"error": "Wikipedia query cannot be empty"}, indent=2)
success = insert_data_from_wikipedia(wiki_query)
if not success:
return json.dumps({"error": f"No data found for Wikipedia query: {wiki_query}"}, indent=2)
if not cypher_query.strip():
return json.dumps({"error": "Cypher query cannot be empty"}, indent=2)
result = query_neo4j(cypher_query)
return result
# Create the Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Wikipedia query"),
        gr.Textbox(label="Cypher query"),
    ],
    outputs="json",
    title="Neo4j and Wikipedia Interface",
    description="Insert data from Wikipedia and query the Neo4j database.",
)
# Launch the Gradio app
if __name__ == "__main__":
iface.launch()