File size: 730 Bytes
c6180c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from typing import Iterable, List, Union

from langchain.indexes.graph import *
from langchain.indexes.graph import GraphIndexCreator as OriginalGraphIndexCreator


class GraphIndexCreator(OriginalGraphIndexCreator):
    """``GraphIndexCreator`` that can build one graph from several texts.

    Extends the upstream LangChain ``GraphIndexCreator`` with
    :meth:`from_texts`, which runs knowledge-triple extraction on every
    input text and merges all extracted triples into a single graph.
    """

    def from_texts(
        self, texts: Union[str, Iterable[str]]
    ) -> NetworkxEntityGraph:
        """Create a single graph index from multiple texts.

        Args:
            texts: The texts to extract knowledge triples from. A single
                string is treated as a one-element collection rather than
                being iterated character by character.

        Returns:
            A new graph, created via ``self.graph_type()``, containing the
            triples parsed from every text.

        Raises:
            ValueError: If ``self.llm`` is ``None``.
        """
        if self.llm is None:
            raise ValueError("llm should not be None")
        # A bare ``str`` is itself iterable; without this guard each
        # character would be sent to the LLM as a separate "text".
        if isinstance(texts, str):
            texts = [texts]
        graph = self.graph_type()
        # Build the extraction chain once and reuse it for every text.
        chain = LLMChain(llm=self.llm, prompt=KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT)

        for text in texts:
            output = chain.predict(text=text)
            for triple in parse_triples(output):
                graph.add_triple(triple)
        return graph