from langchain_community.graphs.neo4j_graph import Neo4jGraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WikipediaLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Connect to a hosted Neo4j Aura instance.
url = "neo4j+s://ddb8863b.databases.neo4j.io"
username = "neo4j"
password = "vz6OLij_IrY-cSIgSMhUWxblTUzH8m4bZaBeJGgmtU0"
graph = Neo4jGraph(url=url, username=username, password=password)

# Split the Wikipedia article into chunks small enough for REBEL's
# 512-token input window.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,
    length_function=len,
    is_separator_regex=False,
)
query = "Dune (Frank Herbert)"
raw_documents = WikipediaLoader(query=query).load_and_split(
    text_splitter=text_splitter
)

# REBEL is a seq2seq model that rewrites text as linearized triplets:
# "<triplet> head <subj> tail <obj> relation".
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")


def extract_relations_from_model_output(text):
    """Parse REBEL's linearized output into a list of triplet dicts."""
    relations = []
    relation, subject, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            # Accumulate plain tokens into whichever slot is currently open.
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        relations.append({
            'head': subject.strip(),
            'type': relation.strip(),
            'tail': object_.strip()
        })
    return relations


class KB():
    """In-memory knowledge base that deduplicates extracted triplets."""

    def __init__(self):
        self.relations = []

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def add_relation(self, r):
        if not self.exists_relation(r):
            self.relations.append(r)

    def print(self):
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")


def from_small_text_to_kb(text, verbose=False):
    kb = KB()

    # Tokenize the text
    model_inputs = tokenizer(text, max_length=512, padding=True,
                             truncation=True, return_tensors='pt')
    if verbose:
        print(f"Num tokens: {len(model_inputs['input_ids'][0])}")

    # Generate three candidate sequences per chunk with beam search.
    gen_kwargs = {
        "max_length": 216,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": 3
    }
    generated_tokens = model.generate(
        **model_inputs,
        **gen_kwargs,
    )
    # Keep special tokens: the parser needs <triplet>, <subj> and <obj>.
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)

    # Parse each candidate sequence and merge its triplets into the KB.
    for sentence_pred in decoded_preds:
        relations = extract_relations_from_model_output(sentence_pred)
        for r in relations:
            kb.add_relation(r)

    return kb


# Extract triplets chunk by chunk and MERGE them into Neo4j. Backticks allow
# entity and relation names containing spaces to be used as labels and types.
for doc in raw_documents:
    kb = from_small_text_to_kb(doc.page_content, verbose=True)
    for relation in kb.relations:
        head = relation['head']
        relationship = relation['type']
        tail = relation['tail']
        cypher = (f"MERGE (h:`{head}`)"
                  f" MERGE (t:`{tail}`)"
                  f" MERGE (h)-[:`{relationship}`]->(t)")
        print(cypher)
        graph.query(cypher)

graph.refresh_schema()
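
# --- Optional sanity checks (not part of the original pipeline) ---
# A minimal sketch to verify the parser without calling the model: the
# string below is a hand-written example of REBEL's output format, not
# actual model output.
sample_output = "<s><triplet> Dune <subj> Frank Herbert <obj> author</s>"
assert extract_relations_from_model_output(sample_output) == [
    {'head': 'Dune', 'type': 'author', 'tail': 'Frank Herbert'}
]

# refresh_schema() above repopulates the graph's schema string, which
# langchain_community's Neo4jGraph exposes as the `schema` attribute.
print(graph.schema)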