Spaces:
Runtime error
Runtime error
"""Chain for applying constitutional principles to the outputs of another chain.""" | |
from typing import Any, Dict, List, Optional | |
from langchain.chains.base import Chain | |
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple | |
from langchain.chains.constitutional_ai.principles import PRINCIPLES | |
from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT | |
from langchain.chains.llm import LLMChain | |
from langchain.prompts.base import BasePromptTemplate | |
from langchain.schema import BaseLanguageModel | |
class ConstitutionalChain(Chain):
    """Chain for applying constitutional principles.

    The wrapped ``chain`` is run once to get an initial response. Then, for
    each entry in ``constitutional_principles`` (in order), ``critique_chain``
    is asked to critique the current response and ``revision_chain`` is asked
    to rewrite it; the revision becomes the response fed to the next
    principle. The final response is returned under the ``"output"`` key.

    Example:
        .. code-block:: python

            from langchain.llms import OpenAI
            from langchain.prompts import PromptTemplate
            from langchain.chains import LLMChain, ConstitutionalChain
            from langchain.chains.constitutional_ai.models import (
                ConstitutionalPrinciple,
            )

            llm = OpenAI()
            qa_prompt = PromptTemplate(
                template="Q: {question} A:",
                input_variables=["question"],
            )
            qa_chain = LLMChain(llm=llm, prompt=qa_prompt)

            constitutional_chain = ConstitutionalChain.from_llm(
                llm=llm,
                chain=qa_chain,
                constitutional_principles=[
                    ConstitutionalPrinciple(
                        critique_request="Tell if this answer is good.",
                        revision_request="Give a better answer.",
                    )
                ],
            )

            constitutional_chain.run(question="What is the meaning of life?")
    """

    # Chain whose output is critiqued and revised.
    chain: LLMChain
    # Principles applied one after another, in list order.
    constitutional_principles: List[ConstitutionalPrinciple]
    # Chain that produces a critique of the current response.
    critique_chain: LLMChain
    # Chain that rewrites the current response given the critique.
    revision_chain: LLMChain

    @classmethod
    def get_principles(
        cls, names: Optional[List[str]] = None
    ) -> List[ConstitutionalPrinciple]:
        """Return principles from the ``PRINCIPLES`` registry.

        Args:
            names: Keys to look up in ``PRINCIPLES``. When ``None``, every
                registered principle is returned.

        Returns:
            The selected principles, in the order of ``names`` (or in the
            registry's own order when ``names`` is ``None``).

        Note:
            A name that is not in ``PRINCIPLES`` is not handled here; the
            lookup error propagates to the caller.
        """
        if names is None:
            return list(PRINCIPLES.values())
        else:
            return [PRINCIPLES[name] for name in names]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        chain: LLMChain,
        critique_prompt: BasePromptTemplate = CRITIQUE_PROMPT,
        revision_prompt: BasePromptTemplate = REVISION_PROMPT,
        **kwargs: Any,
    ) -> "ConstitutionalChain":
        """Create a chain from an LLM.

        Args:
            llm: Language model used for both the critique and the revision
                step.
            chain: Chain that produces the initial response.
            critique_prompt: Prompt for the critique chain.
            revision_prompt: Prompt for the revision chain.
            **kwargs: Forwarded unchanged to the class constructor (for
                example ``constitutional_principles``).

        Returns:
            A new instance wired with freshly built critique and revision
            chains.
        """
        critique_chain = LLMChain(llm=llm, prompt=critique_prompt)
        revision_chain = LLMChain(llm=llm, prompt=revision_prompt)
        return cls(
            chain=chain,
            critique_chain=critique_chain,
            revision_chain=revision_chain,
            **kwargs,
        )

    @property
    def input_keys(self) -> List[str]:
        """Defines the input keys (identical to those of the wrapped chain)."""
        return self.chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """Defines the output keys (always the single key ``"output"``)."""
        return ["output"]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Run the wrapped chain, then critique and revise per principle.

        Args:
            inputs: Values passed as keyword arguments to the wrapped chain
                and used to format its prompt.

        Returns:
            ``{"output": <response after the last revision>}``. With no
            principles configured this is the wrapped chain's own response.
        """
        response = self.chain.run(**inputs)
        # The formatted prompt is shown to the critique/revision chains so
        # they can judge the response in the context of the original request.
        input_prompt = self.chain.prompt.format(**inputs)

        self.callback_manager.on_text(
            text="Initial response: " + response + "\n\n",
            verbose=self.verbose,
            color="yellow",
        )

        for constitutional_principle in self.constitutional_principles:
            # Do critique
            raw_critique = self.critique_chain.run(
                input_prompt=input_prompt,
                output_from_model=response,
                critique_request=constitutional_principle.critique_request,
            )
            critique = self._parse_critique(
                output_string=raw_critique,
            ).strip()

            # Do revision
            revision = self.revision_chain.run(
                input_prompt=input_prompt,
                output_from_model=response,
                critique_request=constitutional_principle.critique_request,
                critique=critique,
                revision_request=constitutional_principle.revision_request,
            ).strip()
            # Each revision is the starting point for the next principle.
            response = revision

            self.callback_manager.on_text(
                text=f"Applying {constitutional_principle.name}..." + "\n\n",
                verbose=self.verbose,
                color="green",
            )

            self.callback_manager.on_text(
                text="Critique: " + critique + "\n\n",
                verbose=self.verbose,
                color="blue",
            )

            self.callback_manager.on_text(
                text="Updated response: " + revision + "\n\n",
                verbose=self.verbose,
                color="yellow",
            )

        return {"output": response}

    @staticmethod
    def _parse_critique(output_string: str) -> str:
        """Extract the critique text from raw critique-chain output.

        If the text contains no ``"Revision request:"`` marker it is returned
        unchanged. Otherwise everything from the first marker onward is
        dropped, and of what remains only the part before the first blank
        line (``"\\n\\n"``) is kept.

        Args:
            output_string: Raw text produced by the critique chain.

        Returns:
            The critique portion of ``output_string`` (not stripped).
        """
        if "Revision request:" not in output_string:
            return output_string
        output_string = output_string.split("Revision request:")[0]
        if "\n\n" in output_string:
            output_string = output_string.split("\n\n")[0]
        return output_string