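# Build a knowledge graph in Neo4j from Wikipedia text: chunk the article with
# LangChain, extract (head, relation, tail) triplets with Babelscape/rebel-large,
# and MERGE each triplet into the graph.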
from langchain_community.graphs.neo4j_graph import Neo4jGraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WikipediaLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
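# Connection details for a hosted Neo4j (Aura) instance.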
url = "neo4j+s://ddb8863b.databases.neo4j.io"
username = "neo4j"
password = "vz6OLij_IrY-cSIgSMhUWxblTUzH8m4bZaBeJGgmtU0"
graph = Neo4jGraph(url=url, username=username, password=password)
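# Split long articles into 512-character chunks, small enough to fit
# comfortably within REBEL's 512-token input window.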
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,
    length_function=len,
    is_separator_regex=False,
)
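# Load the Wikipedia results for the query and split them with the chunker above.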
query = "Dune (Frank Herbert)"
raw_documents = WikipediaLoader(query=query).load_and_split(text_splitter=text_splitter)
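# REBEL is a seq2seq (BART-based) relation-extraction model that emits triplets
# as text, delimited by special <triplet>/<subj>/<obj> markers.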
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
def extract_relations_from_model_output(text):
    # REBEL linearizes triplets as: <triplet> head <subj> tail <obj> relation,
    # so state 't' accumulates the head, 's' the tail, and 'o' the relation.
    relations = []
    relation, subject, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    # Strip the ordinary special tokens; the triplet markers are parsed below.
    text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        relations.append({
            'head': subject.strip(),
            'type': relation.strip(),
            'tail': object_.strip()
        })
    return relations
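# Minimal in-memory knowledge base that stores triplets and skips exact duplicates.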
class KB():
    def __init__(self):
        self.relations = []

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def add_relation(self, r):
        if not self.exists_relation(r):
            self.relations.append(r)

    def print(self):
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")
def from_small_text_to_kb(text, verbose=False):
    kb = KB()
    # Tokenize the text
    model_inputs = tokenizer(text, max_length=512, padding=True, truncation=True,
                             return_tensors='pt')
    if verbose:
        print(f"Num tokens: {len(model_inputs['input_ids'][0])}")
    # Generate with beam search, returning 3 sequences to surface more candidate triplets
    gen_kwargs = {
        "max_length": 216,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": 3
    }
    generated_tokens = model.generate(
        **model_inputs,
        **gen_kwargs,
    )
    # Keep special tokens: the parser needs the <triplet>/<subj>/<obj> markers
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
    # Create the KB from the decoded predictions
    for sentence_pred in decoded_preds:
        relations = extract_relations_from_model_output(sentence_pred)
        for r in relations:
            kb.add_relation(r)
    return kb
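# Write every extracted triplet to Neo4j. MERGE is idempotent, so re-running the
# script will not create duplicate nodes or relationships. Note that the entity
# text is used as a node *label*; the backticks allow spaces and punctuation.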
for doc in raw_documents:
    kb = from_small_text_to_kb(doc.page_content, verbose=True)
    for relation in kb.relations:
        head = relation['head']
        relationship = relation['type']
        tail = relation['tail']
        cypher = f"MERGE (h:`{head}`) MERGE (t:`{tail}`) MERGE (h)-[:`{relationship}`]->(t)"
        print(cypher)
        graph.query(cypher)
graph.refresh_schema()
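# Optional sanity check, as a sketch: list a few stored triplets with raw Cypher
# through the same connection (this query is illustrative, not part of the
# original script).
for row in graph.query("MATCH (h)-[r]->(t) RETURN labels(h) AS head, type(r) AS rel, labels(t) AS tail LIMIT 10"):
    print(row)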