|
from typing import Any, Optional |
|
from langchain.chains import LLMChain |
|
from langchain.base_language import BaseLanguageModel |
|
from langchain.prompts import PromptTemplate |
|
from langchain.memory.chat_memory import BaseMemory |
|
from models import llm |
|
|
|
from promopts import CONTENT_RE_WRIGHT_PROMPT, FEEDBACK_PROMPT |
|
|
|
|
|
def _default_feedback_llm() -> BaseLanguageModel:
    """Build the model used when ``HumanFeedBackChain`` is created without ``llm``."""
    # ``llm`` here is the factory imported from ``models``. Inside
    # ``HumanFeedBackChain.__init__`` that name is shadowed by the parameter,
    # so the default model has to be built from outside the method.
    return llm(temperature=0.7)


class HumanFeedBackChain(LLMChain):
    """``LLMChain`` whose positional ``run`` input is prefixed with ``"Answer:"``.

    By default it uses ``FEEDBACK_PROMPT`` and a model built with
    ``llm(temperature=0.7)``, and it runs in verbose mode.
    """

    # Optional memory handed to the underlying chain; re-declared here with a
    # ``None`` default.
    memory: Optional[BaseMemory] = None

    def __init__(
        self,
        verbose=True,
        llm: Optional[BaseLanguageModel] = None,
        memory: Optional[BaseMemory] = None,
        prompt: PromptTemplate = FEEDBACK_PROMPT,
    ):
        """Create the chain.

        Args:
            verbose: Passed through to ``LLMChain``.
            llm: Model to query. When omitted, a new model is built for this
                instance via ``llm(temperature=0.7)``. The default is not a
                single model created at import time and shared by every chain.
            memory: Optional memory passed through to ``LLMChain``.
            prompt: Prompt template; defaults to ``FEEDBACK_PROMPT``.
        """
        if llm is None:
            llm = _default_feedback_llm()
        super().__init__(llm=llm, prompt=prompt, memory=memory, verbose=verbose)

    def run(self, *args: Any, **kwargs: Any) -> str:
        """Run the chain and return its single output value.

        A single positional argument is prefixed with ``"Answer:"`` and passed
        to the chain as its input. Keyword arguments are passed through
        unchanged as the input mapping.

        Raises:
            ValueError: if the chain does not have exactly one output key, if
                more than one positional argument is given, if positional and
                keyword arguments are mixed, or if no arguments are given.
        """
        if len(self.output_keys) != 1:
            raise ValueError(
                "`run` not supported when there is not exactly "
                f"one output key. Got {self.output_keys}."
            )
        output_key = self.output_keys[0]

        # Without this guard, a bare ``run()`` would fall through to the
        # "but not both" error below, which misdescribes the problem.
        if not args and not kwargs:
            raise ValueError(
                "`run` supported with either positional arguments or keyword "
                "arguments, but none were provided."
            )

        if args and not kwargs:
            if len(args) != 1:
                raise ValueError("`run` supports only one positional argument.")
            return self("Answer:" + args[0])[output_key]

        if kwargs and not args:
            return self(kwargs)[output_key]

        raise ValueError(
            "`run` supported with either positional arguments or keyword arguments"
            f" but not both. Got args: {args} and kwargs: {kwargs}."
        )
|
|
|
|
|
# Module-level chain that pairs CONTENT_RE_WRIGHT_PROMPT with a model built
# once, at import time, via llm(temperature=0.7).
contextRewriteChain = LLMChain(
    prompt=CONTENT_RE_WRIGHT_PROMPT,
    llm=llm(temperature=0.7),
)