# kgraph/insertData.py
#
# Build a knowledge graph in Neo4j: load a Wikipedia article, split it into
# chunks, extract (head, relation, tail) triplets with Babelscape/rebel-large,
# and MERGE them into the graph.
from langchain_community.graphs.neo4j_graph import Neo4jGraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WikipediaLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Neo4j Aura connection details. Hard-coding credentials like this is fine for
# a throwaway demo, but in practice read them from environment variables.
url = "neo4j+s://ddb8863b.databases.neo4j.io"
username = "neo4j"
password = "vz6OLij_IrY-cSIgSMhUWxblTUzH8m4bZaBeJGgmtU0"
graph = Neo4jGraph(url=url, username=username, password=password)
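# Optional sanity check (a minimal sketch, safe to delete): run a trivial
# Cypher query to confirm the connection works before doing any writes.
print(graph.query("RETURN 1 AS ok"))  # expect [{'ok': 1}]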
# Split long Wikipedia pages into ~512-character chunks so each chunk stays
# well under REBEL's 512-token input limit.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,
    length_function=len,
    is_separator_regex=False,
)
# Load the matching Wikipedia article(s) and split them on the fly.
query = "Dune (Frank Herbert)"
raw_documents = WikipediaLoader(query=query).load_and_split(text_splitter=text_splitter)
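# Illustrative peek at the chunking (assumes the query matched an article):
print(f"Loaded {len(raw_documents)} chunks")
print(raw_documents[0].page_content[:120])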
# REBEL (Babelscape/rebel-large) is a seq2seq model that rewrites input text
# as a linearized sequence of (head, relation, tail) triplets.
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
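# REBEL's decoded output is plain text with special marker tokens, e.g.
# (illustrative, not real model output):
#   <s><triplet> Dune <subj> Frank Herbert <obj> author </s>
# i.e. "<triplet> head <subj> tail <obj> relation", repeated once per triplet.
# extract_relations_from_model_output() below parses that back into dicts.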
def extract_relations_from_model_output(text):
    """Parse REBEL's linearized output into a list of triplet dicts."""
    relations = []
    relation, subject, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    # Strip the special tokens that wrap the generated sequence.
    text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<triplet>":
            # A new triplet starts; flush the previous one if it is complete.
            current = 't'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
                relation = ''
            subject = ''
        elif token == "<subj>":
            # The tail entity follows; flush a completed triplet if present.
            current = 's'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
            object_ = ''
        elif token == "<obj>":
            # The relation label follows.
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    # Flush the final triplet.
    if subject != '' and relation != '' and object_ != '':
        relations.append({
            'head': subject.strip(),
            'type': relation.strip(),
            'tail': object_.strip()
        })
    return relations
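# Quick self-test of the parser on a hand-written REBEL-style string
# (illustrative, not actual model output):
assert extract_relations_from_model_output(
    "<s><triplet> Dune <subj> Frank Herbert <obj> author </s>"
) == [{'head': 'Dune', 'type': 'author', 'tail': 'Frank Herbert'}]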
class KB():
    """Tiny in-memory store that de-duplicates extracted relations."""
    def __init__(self):
        self.relations = []

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def add_relation(self, r):
        if not self.exists_relation(r):
            self.relations.append(r)

    def print(self):
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")
def from_small_text_to_kb(text, verbose=False):
    kb = KB()
    # Tokenize the input text (truncated to REBEL's 512-token limit).
    model_inputs = tokenizer(text, max_length=512, padding=True, truncation=True,
                             return_tensors='pt')
    if verbose:
        print(f"Num tokens: {len(model_inputs['input_ids'][0])}")
    # Generate with beam search, keeping all three candidate sequences.
    gen_kwargs = {
        "max_length": 216,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": 3
    }
    generated_tokens = model.generate(
        **model_inputs,
        **gen_kwargs,
    )
    # Keep special tokens: the parser needs the <triplet>/<subj>/<obj> markers.
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
    # Parse each candidate and merge into the KB; duplicates across the three
    # beams are dropped by KB.add_relation().
    for sentence_pred in decoded_preds:
        relations = extract_relations_from_model_output(sentence_pred)
        for r in relations:
            kb.add_relation(r)
    return kb
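# Single-chunk smoke test (illustrative; uncomment to run the model once):
# kb = from_small_text_to_kb("Dune is a 1965 science fiction novel by Frank Herbert.", verbose=True)
# kb.print()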
# For every chunk: extract triplets, then MERGE each one into Neo4j. Note the
# scheme used here: the entity text becomes the node *label* and the relation
# text becomes the relationship type, hence the backtick quoting.
for doc in raw_documents:
    kb = from_small_text_to_kb(doc.page_content, verbose=True)
    for relation in kb.relations:
        head = relation['head']
        relationship = relation['type']
        tail = relation['tail']
        cypher = f"MERGE (h:`{head}`) MERGE (t:`{tail}`) MERGE (h)-[:`{relationship}`]->(t)"
        print(cypher)
        graph.query(cypher)
# Refresh the cached schema once, after all writes.
graph.refresh_schema()
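# A more robust variant (a sketch, not part of the original script): model
# entities as nodes with a generic label and a name property, passed as query
# parameters, since raw REBEL strings interpolated into labels break on
# backtick characters and bloat the schema with one label per entity.
# Cypher cannot parameterize relationship types, so the type is stored as a
# property instead:
#   graph.query(
#       "MERGE (h:Entity {name: $head}) "
#       "MERGE (t:Entity {name: $tail}) "
#       "MERGE (h)-[:RELATED {type: $rel}]->(t)",
#       params={"head": head, "tail": tail, "rel": relationship},
#   )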