import os
from typing import List, Optional

from langchain.chains.base import Chain
from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import PromptTemplate

from app_modules.llm_inference import LLMInference


def get_llama_2_prompt_template(instruction: str) -> str:
    """Wrap an instruction in the Llama 2 chat format:
    [INST] <<SYS>> system prompt <</SYS>> instruction [/INST]."""
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    system_prompt = (
        "You are a helpful assistant, you always only answer for the "
        "assistant then you stop. Read the text to get context"
    )

    SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
    prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
    return prompt_template
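
# For example, with an illustrative instruction such as "Summarize: {text}",
# get_llama_2_prompt_template produces (newlines come from B_SYS/E_SYS above):
#
#   [INST]<<SYS>>
#   You are a helpful assistant, you always only answer for the assistant then you stop. Read the text to get context
#   <</SYS>>
#
#   Summarize: {text}[/INST]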


class SummarizeChain(LLMInference):
    """Summarization via LangChain's "refine" chain: summarize the first
    document chunk, then iteratively refine the summary with each
    remaining chunk."""

    def __init__(self, llm_loader):
        super().__init__(llm_loader)

    def create_chain(self) -> Chain:
        # Opt in to Llama 2 chat formatting via an environment variable.
        use_llama_2_prompt_template = (
            os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
        )
        # Initial prompt, applied to the first document chunk only.
        prompt_template = """Write a concise summary of the following:
{text}
CONCISE SUMMARY:"""

        if use_llama_2_prompt_template:
            prompt_template = get_llama_2_prompt_template(prompt_template)
        prompt = PromptTemplate.from_template(prompt_template)

        # Refine prompt: fold each subsequent chunk into the running summary.
        refine_template = (
            "Your job is to produce a final summary\n"
            "We have provided an existing summary up to a certain point: {existing_answer}\n"
            "We have the opportunity to refine the existing summary "
            "(only if needed) with some more context below.\n"
            "------------\n"
            "{text}\n"
            "------------\n"
            "Given the new context, refine the original summary. "
            "If the context isn't useful, return the original summary."
        )

        if use_llama_2_prompt_template:
            refine_template = get_llama_2_prompt_template(refine_template)
        refine_prompt = PromptTemplate.from_template(refine_template)

        # The "refine" chain runs chunks sequentially: question_prompt on the
        # first chunk, refine_prompt on each one after, and the per-step
        # summaries are returned alongside the final text.
        chain = load_summarize_chain(
            llm=self.llm_loader.llm,
            chain_type="refine",
            question_prompt=prompt,
            refine_prompt=refine_prompt,
            return_intermediate_steps=True,
            input_key="input_documents",
            output_key="output_text",
        )
        return chain

    def run_chain(self, chain, inputs, callbacks: Optional[List] = None):
        # Use None instead of a mutable default, and actually forward the
        # callbacks to the chain invocation.
        result = chain(inputs, return_only_outputs=True, callbacks=callbacks)
        return result
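
# A minimal usage sketch. Assumptions not confirmed by this file: the project
# provides a loader object exposing a ready `.llm` attribute, and the input
# text has already been split into chunks; `llm_loader` and `text_chunks`
# below are illustrative names.
#
#   from langchain.docstore.document import Document
#
#   summarize_chain = SummarizeChain(llm_loader)
#   chain = summarize_chain.create_chain()
#   docs = [Document(page_content=chunk) for chunk in text_chunks]
#   result = summarize_chain.run_chain(chain, {"input_documents": docs})
#   print(result["output_text"])          # final refined summary
#   print(result["intermediate_steps"])   # one summary per refine step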