from langchain_community.graphs.neo4j_graph import Neo4jGraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WikipediaLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

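# Connect to the Neo4j Aura instance that will store the extracted knowledge graph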
url = "neo4j+s://ddb8863b.databases.neo4j.io"
username = "neo4j"
password = "vz6OLij_IrY-cSIgSMhUWxblTUzH8m4bZaBeJGgmtU0"
graph = Neo4jGraph(url=url, username=username, password=password)


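# Split documents into 512-character chunks so each chunk fits the
# relation-extraction model's input window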
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,
    length_function=len,
    is_separator_regex=False,
)

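# Load Wikipedia pages matching the query and split them with the splitter above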
query = "Dune (Frank Herbert)"
raw_documents = WikipediaLoader(query=query).load_and_split(text_splitter=text_splitter)


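# REBEL (Babelscape/rebel-large) is a seq2seq model that generates relation
# triplets as linearized text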
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")

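# Parse REBEL's linearized output ("<triplet> head <subj> tail <obj> relation ...")
# into a list of {'head', 'type', 'tail'} dicts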
def extract_relations_from_model_output(text):
    relations = []
    relation, subject, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    # Strip the generation special tokens, keeping only the triplet markers and text
    text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        relations.append({
            'head': subject.strip(),
            'type': relation.strip(),
            'tail': object_.strip()
        })
    return relations

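# Minimal in-memory knowledge base that stores relations and skips exact duplicates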
class KB():
    def __init__(self):
        self.relations = []

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def add_relation(self, r):
        if not self.exists_relation(r):
            self.relations.append(r)

    def print(self):
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")

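# Tokenize one text chunk, generate three candidate linearizations with beam
# search, and merge every extracted relation into a KB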
def from_small_text_to_kb(text, verbose=False):
    kb = KB()

    # Tokenize the text
    model_inputs = tokenizer(text, max_length=512, padding=True, truncation=True,
                             return_tensors='pt')

    if verbose:
        print(f"Num tokens: {len(model_inputs['input_ids'][0])}")

    # Generate the linearized triplets
    gen_kwargs = {
        "max_length": 216,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": 3
    }

    generated_tokens = model.generate(
        **model_inputs,
        **gen_kwargs,
    )

    # Decode with special tokens kept so the triplet markers can be parsed
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

    # Create the KB from the decoded predictions
    for sentence_pred in decoded_preds:
        relations = extract_relations_from_model_output(sentence_pred)
        for r in relations:
            kb.add_relation(r)

    return kb

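# For each chunk: build a KB and MERGE its relations into Neo4j, using the
# entity and relation names directly as node labels and relationship types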
for doc in raw_documents:
    kb = from_small_text_to_kb(doc.page_content, verbose=True)

    for relation in kb.relations:
        head = relation['head']
        relationship = relation['type']
        tail = relation['tail']

        cypher = f"MERGE (h:`{head}`) MERGE (t:`{tail}`) MERGE (h)-[:`{relationship}`]->(t)"
        print(cypher)
        graph.query(cypher)


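# Refresh the cached schema so it reflects the newly created labels and relationship types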
graph.refresh_schema()