Gary0205 committed
Commit 6c4bf7b
Parent: b932eb5

Create insertData.py

Files changed (1):
  insertData.py +130 -0
insertData.py ADDED
@@ -0,0 +1,130 @@
+ from langchain_community.graphs.neo4j_graph import Neo4jGraph
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import WikipediaLoader
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+ url = "neo4j+s://ddb8863b.databases.neo4j.io"
+ username = "neo4j"
+ password = "vz6OLij_IrY-cSIgSMhUWxblTUzH8m4bZaBeJGgmtU0"
+ graph = Neo4jGraph(url=url, username=username, password=password)
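+ # Optional: a minimal connectivity check before doing any work (illustrative,
+ # not required — Neo4jGraph already verifies the connection on construction):
+ # graph.query("RETURN 1")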
+
+
+ text_splitter = RecursiveCharacterTextSplitter(
+     chunk_size=512,
+     length_function=len,
+     is_separator_regex=False,
+ )
+
+ query = "Dune (Frank Herbert)"
+ raw_documents = WikipediaLoader(query=query).load_and_split(text_splitter=text_splitter)
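+ # Each chunk is a langchain Document. Illustrative shape only — the actual
+ # text and metadata come from Wikipedia at runtime:
+ #   raw_documents[0].page_content  -> "Dune is a 1965 epic science fiction novel ..."
+ #   raw_documents[0].metadata      -> {'title': 'Dune (novel)', 'summary': ..., 'source': ...}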
+
+
+ # REBEL: a BART-based seq2seq model that emits relation triplets as text
+ tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
+ model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
+
+ # REBEL linearizes each triplet as "<triplet> head <subj> tail <obj> relation";
+ # this parser walks the decoded string and rebuilds the triplets.
+ def extract_relations_from_model_output(text):
+     relations = []
+     subject, relation, object_ = '', '', ''
+     text = text.strip()
+     current = 'x'
+     text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
+     for token in text_replaced.split():
+         if token == "<triplet>":
+             current = 't'
+             if relation != '':
+                 relations.append({
+                     'head': subject.strip(),
+                     'type': relation.strip(),
+                     'tail': object_.strip()
+                 })
+                 relation = ''
+             subject = ''
+         elif token == "<subj>":
+             current = 's'
+             if relation != '':
+                 relations.append({
+                     'head': subject.strip(),
+                     'type': relation.strip(),
+                     'tail': object_.strip()
+                 })
+             object_ = ''
+         elif token == "<obj>":
+             current = 'o'
+             relation = ''
+         else:
+             if current == 't':
+                 subject += ' ' + token
+             elif current == 's':
+                 object_ += ' ' + token
+             elif current == 'o':
+                 relation += ' ' + token
+     if subject != '' and relation != '' and object_ != '':
+         relations.append({
+             'head': subject.strip(),
+             'type': relation.strip(),
+             'tail': object_.strip()
+         })
+     return relations
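+ # Worked example: for a decoded sequence such as
+ #   "<s><triplet> Dune <subj> Frank Herbert <obj> author</s>"
+ # the function returns
+ #   [{'head': 'Dune', 'type': 'author', 'tail': 'Frank Herbert'}]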
+
+ class KB():
+     def __init__(self):
+         self.relations = []
+
+     def are_relations_equal(self, r1, r2):
+         return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])
+
+     def exists_relation(self, r1):
+         return any(self.are_relations_equal(r1, r2) for r2 in self.relations)
+
+     def add_relation(self, r):
+         if not self.exists_relation(r):
+             self.relations.append(r)
+
+     def print(self):
+         print("Relations:")
+         for r in self.relations:
+             print(f"  {r}")
+
+ def from_small_text_to_kb(text, verbose=False):
+     kb = KB()
+
+     # Tokenize the chunk (truncated to the model's 512-token limit)
+     model_inputs = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')
+
+     if verbose:
+         print(f"Num tokens: {len(model_inputs['input_ids'][0])}")
+
+     # Generate with beam search, keeping all three beams
+     gen_kwargs = {
+         "max_length": 216,
+         "length_penalty": 0,
+         "num_beams": 3,
+         "num_return_sequences": 3
+     }
+     generated_tokens = model.generate(
+         **model_inputs,
+         **gen_kwargs,
+     )
+
+     # Keep special tokens: the parser needs the triplet delimiters
+     decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
+
+     # Build the KB, de-duplicating relations across beams
+     for sentence_pred in decoded_preds:
+         relations = extract_relations_from_model_output(sentence_pred)
+         for r in relations:
+             kb.add_relation(r)
+
+     return kb
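+ # Illustrative usage (the exact relations depend on the model's predictions):
+ #   kb = from_small_text_to_kb("Dune is a novel by Frank Herbert.", verbose=True)
+ #   kb.print()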
+
+ for doc in raw_documents:
+     kb = from_small_text_to_kb(doc.page_content, verbose=True)
+
+     # Write each extracted triplet into Neo4j. Entity names become node
+     # labels and the relation type becomes the relationship type; back-ticks
+     # quote spaces and most punctuation in the generated Cypher.
+     for relation in kb.relations:
+         head = relation['head']
+         relationship = relation['type']
+         tail = relation['tail']
+
+         cypher = f"MERGE (h:`{head}`) MERGE (t:`{tail}`) MERGE (h)-[:`{relationship}`]->(t)"
+         print(cypher)
+         graph.query(cypher)
+
+ graph.refresh_schema()
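+
+ # Optional sanity check: count the relationships that were just merged.
+ # print(graph.query("MATCH ()-[r]->() RETURN count(r) AS rels"))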