"""Gradio app: extract relation triplets from Wikipedia text with the REBEL
seq2seq model and store them in a Neo4j graph, then query the graph with
user-supplied Cypher."""

import json
import os

import gradio as gr
from py2neo import Graph
from langchain_community.graphs.neo4j_graph import Neo4jGraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WikipediaLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Neo4j connection settings.
# SECURITY: credentials were hard-coded in source. They are now read from the
# environment; the literal fallbacks keep the original behavior when the env
# vars are unset, but the exposed password should be rotated.
url = os.environ.get("NEO4J_URI", "neo4j+s://ddb8863b.databases.neo4j.io")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get(
    "NEO4J_PASSWORD", "vz6OLij_IrY-cSIgSMhUWxblTUzH8m4bZaBeJGgmtU0"
)

graph = Graph(url, auth=(username, password))
neo4j_graph = Neo4jGraph(url=url, username=username, password=password)

# REBEL linearizes relations as "<triplet> head <subj> tail <obj> relation ...".
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")


def extract_relations_from_model_output(text):
    """Parse REBEL's linearized decoder output into relation dicts.

    Parameters
    ----------
    text : str
        One decoded sequence (special tokens NOT skipped).

    Returns
    -------
    list[dict]
        Dicts with keys 'head', 'type', 'tail'.
    """
    triplets = []
    subject, relation, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    # BUG FIX: the REBEL marker tokens had been stripped to empty strings
    # (token == "" can never match output of str.split()), so the parser
    # never extracted anything. Restored the canonical markers from the
    # Babelscape/rebel-large model card.
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                triplets.append({'head': subject.strip(),
                                 'type': relation.strip(),
                                 'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                triplets.append({'head': subject.strip(),
                                 'type': relation.strip(),
                                 'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            # Accumulate plain-text tokens into whichever slot is active.
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    # Flush the final (unterminated) triplet, if complete.
    if subject != '' and relation != '' and object_ != '':
        triplets.append({'head': subject.strip(),
                         'type': relation.strip(),
                         'tail': object_.strip()})
    # BUG FIX: removed unreachable code after this return that referenced the
    # undefined names `extract_triplets` and `extracted_text`.
    return triplets


class KB:
    """In-memory collection of unique relation triplets."""

    def __init__(self):
        # Each entry is a dict with keys 'head', 'type', 'tail'.
        self.relations = []

    def are_relations_equal(self, r1, r2):
        """Return True when two relations agree on head, type and tail."""
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        """Return True when an equal relation is already stored."""
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def add_relation(self, r):
        """Store `r` unless an equal relation is already present."""
        if not self.exists_relation(r):
            self.relations.append(r)

    def print(self):
        """Dump all stored relations to stdout."""
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")


def from_small_text_to_kb(text, verbose=False):
    """Run REBEL over a short text and collect extracted relations in a KB.

    Parameters
    ----------
    text : str
        Input passage (truncated to 512 model tokens).
    verbose : bool
        When True, print the tokenized input length.

    Returns
    -------
    KB
        Knowledge base of deduplicated relations.
    """
    kb = KB()
    model_inputs = tokenizer(text, max_length=512, padding=True,
                             truncation=True, return_tensors='pt')
    if verbose:
        print(f"Num tokens: {len(model_inputs['input_ids'][0])}")
    gen_kwargs = {
        "max_length": 216,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": 3,
    }
    generated_tokens = model.generate(**model_inputs, **gen_kwargs)
    # Keep special tokens: the parser relies on <triplet>/<subj>/<obj> markers.
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)
    for sentence_pred in decoded_preds:
        for r in extract_relations_from_model_output(sentence_pred):
            kb.add_relation(r)
    return kb


def _escape_backticks(name):
    """Escape backticks so a model-produced name is safe inside `...` quoting."""
    return name.replace("`", "``")


def insert_data_from_wikipedia(query):
    """Load Wikipedia pages for `query`, extract triplets, MERGE them into Neo4j.

    Returns True when at least one document was processed, False otherwise.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=512, length_function=len, is_separator_regex=False
    )
    raw_documents = WikipediaLoader(query=query).load_and_split(
        text_splitter=text_splitter
    )
    if not raw_documents:
        print("No documents found for query:", query)
        return False
    for doc in raw_documents:
        kb = from_small_text_to_kb(doc.page_content, verbose=True)
        for relation in kb.relations:
            head = relation['head']
            relationship = relation['type']
            tail = relation['tail']
            if head and relationship and tail:
                # Labels and relationship types cannot be parameterized in
                # Cypher, so they are backtick-quoted; embedded backticks are
                # escaped to keep model-generated names from breaking the query.
                cypher = (
                    f"MERGE (h:`{_escape_backticks(head)}`) "
                    f"MERGE (t:`{_escape_backticks(tail)}`) "
                    f"MERGE (h)-[:`{_escape_backticks(relationship)}`]->(t)"
                )
                print(f"Executing Cypher query: {cypher}")  # Debug print for Cypher query
                graph.run(cypher)
            else:
                print(f"Skipping invalid relation: head='{head}', relationship='{relationship}', tail='{tail}'")  # Skip invalid relations
    return True


def query_neo4j(query):
    """Run a Cypher query and return the result (or the error) as a JSON string."""
    if not query.strip():
        return json.dumps({"error": "Empty Cypher query"}, indent=2)  # Handle empty query case
    try:
        result = graph.run(query).data()
        return json.dumps(result, indent=2)  # Convert to JSON string
    except Exception as e:
        # Boundary handler: surface any driver/query error to the UI as JSON.
        return json.dumps({"error": str(e)}, indent=2)


def gradio_interface(wiki_query, cypher_query):
    """Gradio entry point: ingest a Wikipedia topic, then run a Cypher query."""
    if not wiki_query.strip():
        return json.dumps({"error": "Wikipedia query cannot be empty"}, indent=2)
    success = insert_data_from_wikipedia(wiki_query)
    if not success:
        return json.dumps(
            {"error": f"No data found for Wikipedia query: {wiki_query}"}, indent=2
        )
    if not cypher_query.strip():
        return json.dumps({"error": "Cypher query cannot be empty"}, indent=2)
    return query_neo4j(cypher_query)


# Create the Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=["text", "text"],
    outputs="json",
    title="Neo4j and Wikipedia Interface",
    description="Insert data from Wikipedia and query the Neo4j database.",
)

# Launch the Gradio app
if __name__ == "__main__":
    iface.launch()