File size: 2,625 Bytes
58d33f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Question answering over a graph."""
from __future__ import annotations

from typing import Any, Dict, List

from pydantic import Field

from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT
from langchain.chains.llm import LLMChain
from langchain.graphs.networkx_graph import NetworkxEntityGraph, get_entities
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate


class GraphQAChain(Chain):
    """Chain for question-answering against a graph.

    Answering happens in three steps: ``entity_extraction_chain`` is run on
    the question to produce an entity string, ``graph`` is asked for the
    knowledge triplets of every entity parsed from that string, and
    ``qa_chain`` is called with the question plus the collected triplets as
    context.
    """

    # Graph that is queried for the knowledge triplets of each entity.
    graph: NetworkxEntityGraph = Field(exclude=True)
    # LLM chain run on the raw question; its output is parsed by get_entities.
    entity_extraction_chain: LLMChain
    # LLM chain called with the "question" and "context" inputs.
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        return [self.output_key]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLLM,
        qa_prompt: BasePromptTemplate = PROMPT,
        entity_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT,
        **kwargs: Any,
    ) -> GraphQAChain:
        """Initialize from LLM.

        Args:
            llm: Language model shared by both internal chains.
            qa_prompt: Prompt used to build the question-answering chain.
            entity_prompt: Prompt used to build the entity-extraction chain.
            **kwargs: Forwarded unchanged to the class constructor
                (for example ``graph``).

        Returns:
            A new chain whose ``qa_chain`` and ``entity_extraction_chain``
            are ``LLMChain`` instances built from ``llm``.
        """
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        entity_chain = LLMChain(llm=llm, prompt=entity_prompt)

        return cls(qa_chain=qa_chain, entity_extraction_chain=entity_chain, **kwargs)

    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        """Extract entities, look up info and answer question.

        Args:
            inputs: Mapping that holds the question under ``self.input_key``.

        Returns:
            A dict with the single key ``self.output_key`` mapped to the
            value ``qa_chain`` returned under its own output key.

        Raises:
            KeyError: If ``inputs`` has no entry for ``self.input_key``.
        """
        question = inputs[self.input_key]

        entity_string = self.entity_extraction_chain.run(question)

        self.callback_manager.on_text(
            "Entities Extracted:", end="\n", verbose=self.verbose
        )
        self.callback_manager.on_text(
            entity_string, color="green", end="\n", verbose=self.verbose
        )
        entities = get_entities(entity_string)

        # Gather the triplets of *all* entities before joining. Joining per
        # entity and concatenating the results left no newline between the
        # last triplet of one entity and the first triplet of the next, so
        # the two were fused into a single line of context.
        all_triplets: List[str] = []
        for entity in entities:
            all_triplets.extend(self.graph.get_entity_knowledge(entity))
        context = "\n".join(all_triplets)

        self.callback_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        self.callback_manager.on_text(
            context, color="green", end="\n", verbose=self.verbose
        )
        result = self.qa_chain({"question": question, "context": context})
        return {self.output_key: result[self.qa_chain.output_key]}