Spaces:
Runtime error
Runtime error
File size: 3,230 Bytes
58d33f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
"""Experiment with different models."""
from __future__ import annotations
from typing import List, Optional, Sequence
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping, print_text
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate
class ModelLaboratory:
"""Experiment with different models."""
def __init__(self, chains: Sequence[Chain], names: Optional[List[str]] = None):
"""Initialize with chains to experiment with.
Args:
chains: list of chains to experiment with.
"""
for chain in chains:
if not isinstance(chain, Chain):
raise ValueError(
"ModelLaboratory should now be initialized with Chains. "
"If you want to initialize with LLMs, use the `from_llms` method "
"instead (`ModelLaboratory.from_llms(...)`)"
)
if len(chain.input_keys) != 1:
raise ValueError(
"Currently only support chains with one input variable, "
f"got {chain.input_keys}"
)
if len(chain.output_keys) != 1:
raise ValueError(
"Currently only support chains with one output variable, "
f"got {chain.output_keys}"
)
if names is not None:
if len(names) != len(chains):
raise ValueError("Length of chains does not match length of names.")
self.chains = chains
chain_range = [str(i) for i in range(len(self.chains))]
self.chain_colors = get_color_mapping(chain_range)
self.names = names
@classmethod
def from_llms(
cls, llms: List[BaseLLM], prompt: Optional[PromptTemplate] = None
) -> ModelLaboratory:
"""Initialize with LLMs to experiment with and optional prompt.
Args:
llms: list of LLMs to experiment with
prompt: Optional prompt to use to prompt the LLMs. Defaults to None.
If a prompt was provided, it should only have one input variable.
"""
if prompt is None:
prompt = PromptTemplate(input_variables=["_input"], template="{_input}")
chains = [LLMChain(llm=llm, prompt=prompt) for llm in llms]
names = [str(llm) for llm in llms]
return cls(chains, names=names)
def compare(self, text: str) -> None:
"""Compare model outputs on an input text.
If a prompt was provided with starting the laboratory, then this text will be
fed into the prompt. If no prompt was provided, then the input text is the
entire prompt.
Args:
text: input text to run all models on.
"""
print(f"\033[1mInput:\033[0m\n{text}\n")
for i, chain in enumerate(self.chains):
if self.names is not None:
name = self.names[i]
else:
name = str(chain)
print_text(name, end="\n")
output = chain.run(text)
print_text(output, color=self.chain_colors[str(i)], end="\n\n")
|