# Build a knowledge graph in Neo4j from Wikipedia text, using the
# Babelscape/rebel-large relation-extraction model and a Gradio front end.
import json

import gradio as gr
from py2neo import Graph
from langchain_community.graphs.neo4j_graph import Neo4jGraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WikipediaLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
url = "neo4j+s://ddb8863b.databases.neo4j.io" |
|
username = "neo4j" |
|
password = "vz6OLij_IrY-cSIgSMhUWxblTUzH8m4bZaBeJGgmtU0" |
|
graph = Graph(url, auth=(username, password)) |
|
neo4j_graph = Neo4jGraph(url=url, username=username, password=password) |
|
|
|
|
|
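
# A minimal sketch of the environment-variable alternative (the NEO4J_*
# names are illustrative, not mandated by py2neo):
#
#   import os
#   url = os.environ["NEO4J_URI"]
#   username = os.environ["NEO4J_USERNAME"]
#   password = os.environ["NEO4J_PASSWORD"]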

# REBEL is a seq2seq model that emits relation triplets in a linearized
# format: "<triplet> head <subj> tail <obj> relation".
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")

def extract_relations_from_model_output(text):
    """Parse REBEL's linearized output into {'head', 'type', 'tail'} triplets."""
    triplets = []
    subject, relation, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            # A new triplet begins; flush the previous one if it is complete.
            current = 't'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            # Tokens that follow belong to the tail entity.
            current = 's'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            # Tokens that follow belong to the relation label.
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})
    return triplets
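
# Worked example of the parser above, using REBEL's linearization format
# ("<triplet> head <subj> tail <obj> relation"); the sentence is made up:
#
#   extract_relations_from_model_output(
#       "<s><triplet> Paris <subj> France <obj> capital of</s>")
#   # -> [{'head': 'Paris', 'type': 'capital of', 'tail': 'France'}]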

class KB:
    """In-memory collection of unique relation triplets."""

    def __init__(self):
        self.relations = []

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def add_relation(self, r):
        if not self.exists_relation(r):
            self.relations.append(r)

    def print(self):
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")
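
# Usage sketch: duplicate relations are silently dropped.
#
#   kb = KB()
#   kb.add_relation({'head': 'Paris', 'type': 'capital of', 'tail': 'France'})
#   kb.add_relation({'head': 'Paris', 'type': 'capital of', 'tail': 'France'})
#   kb.print()  # prints the single stored relation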

def from_small_text_to_kb(text, verbose=False):
    """Run REBEL over a short text and collect the extracted triplets in a KB."""
    kb = KB()
    model_inputs = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')
    if verbose:
        print(f"Num tokens: {len(model_inputs['input_ids'][0])}")
    # Beam search with several return sequences surfaces more candidate triplets.
    gen_kwargs = {
        "max_length": 216,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": 3,
    }
    generated_tokens = model.generate(**model_inputs, **gen_kwargs)
    # Keep special tokens: the parser relies on the <triplet>/<subj>/<obj> markers.
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
    for sentence_pred in decoded_preds:
        relations = extract_relations_from_model_output(sentence_pred)
        if verbose:
            print(f"Extracted {len(relations)} relations")
        for r in relations:
            kb.add_relation(r)
    return kb
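
# Example: build a KB from a single sentence (the model downloads on first
# use; the exact triplets depend on the model's output):
#
#   kb = from_small_text_to_kb("Napoleon Bonaparte was born in Ajaccio, Corsica.", verbose=True)
#   kb.print()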

def insert_data_from_wikipedia(query):
    """Fetch Wikipedia pages for `query`, extract triplets, and MERGE them into Neo4j."""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, length_function=len, is_separator_regex=False)
    raw_documents = WikipediaLoader(query=query).load_and_split(text_splitter=text_splitter)

    if not raw_documents:
        print("No documents found for query:", query)
        return False

    for doc in raw_documents:
        kb = from_small_text_to_kb(doc.page_content, verbose=True)
        for relation in kb.relations:
            head = relation['head']
            relationship = relation['type']
            tail = relation['tail']
            if head and relationship and tail:
                # Entity names become node labels and the relation phrase becomes
                # the relationship type. Backticks allow spaces, but names that
                # themselves contain backticks would break this f-string Cypher;
                # sanitize the extracted strings before using untrusted input.
                cypher = f"MERGE (h:`{head}`) MERGE (t:`{tail}`) MERGE (h)-[:`{relationship}`]->(t)"
                print(f"Executing Cypher query: {cypher}")
                graph.run(cypher)
            else:
                print(f"Skipping invalid relation: head='{head}', relationship='{relationship}', tail='{tail}'")
    return True
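
# For a relation like {'head': 'Paris', 'type': 'capital of', 'tail': 'France'},
# the generated statement is:
#
#   MERGE (h:`Paris`) MERGE (t:`France`) MERGE (h)-[:`capital of`]->(t)
#
# MERGE keeps ingestion idempotent: re-running the same statement reuses the
# existing nodes and relationship instead of creating duplicates.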

def query_neo4j(query):
    if not query.strip():
        return json.dumps({"error": "Empty Cypher query"}, indent=2)
    try:
        result = graph.run(query).data()
        return json.dumps(result, indent=2)
    except Exception as e:
        return json.dumps({"error": str(e)}, indent=2)
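
# Example read query for the graph built above. Since entity names are stored
# as labels rather than properties, return the labels themselves:
#
#   MATCH (h)-[r]->(t)
#   RETURN labels(h) AS head, type(r) AS rel, labels(t) AS tail
#   LIMIT 25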

def gradio_interface(wiki_query, cypher_query):
    """Ingest Wikipedia data for `wiki_query`, then run `cypher_query` against the graph."""
    # Validate both inputs before kicking off the slow ingestion step.
    if not wiki_query.strip():
        return json.dumps({"error": "Wikipedia query cannot be empty"}, indent=2)
    if not cypher_query.strip():
        return json.dumps({"error": "Cypher query cannot be empty"}, indent=2)
    success = insert_data_from_wikipedia(wiki_query)
    if not success:
        return json.dumps({"error": f"No data found for Wikipedia query: {wiki_query}"}, indent=2)
    return query_neo4j(cypher_query)

iface = gr.Interface(
    fn=gradio_interface,
    inputs=["text", "text"],
    outputs="json",
    title="Neo4j and Wikipedia Interface",
    description="Insert data from Wikipedia and query the Neo4j database.",
)

if __name__ == "__main__":
    iface.launch()