"""Experiment with different models."""
from __future__ import annotations

from typing import List, Optional, Sequence

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping, print_text
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate


class ModelLaboratory:
    """Experiment with different models."""

    def __init__(self, chains: Sequence[Chain], names: Optional[List[str]] = None):
        """Initialize with chains to experiment with.

        Args:
            chains: list of chains to experiment with.
            names: Optional list of names, one per chain, used when printing
                comparison results. Defaults to None.
        """
        for chain in chains:
            if not isinstance(chain, Chain):
                raise ValueError(
                    "ModelLaboratory should now be initialized with Chains. "
                    "If you want to initialize with LLMs, use the `from_llms` method "
                    "instead (`ModelLaboratory.from_llms(...)`)"
                )
            if len(chain.input_keys) != 1:
                raise ValueError(
                    "Currently only support chains with one input variable, "
                    f"got {chain.input_keys}"
                )
            if len(chain.output_keys) != 1:
                raise ValueError(
                    "Currently only support chains with one output variable, "
                    f"got {chain.output_keys}"
                )
        if names is not None:
            if len(names) != len(chains):
                raise ValueError("Length of chains does not match length of names.")
        self.chains = chains
        chain_range = [str(i) for i in range(len(self.chains))]
        self.chain_colors = get_color_mapping(chain_range)
        self.names = names

    @classmethod
    def from_llms(
        cls, llms: List[BaseLLM], prompt: Optional[PromptTemplate] = None
    ) -> ModelLaboratory:
        """Initialize with LLMs to experiment with and optional prompt.

        Args:
            llms: list of LLMs to experiment with
            prompt: Optional prompt to use to prompt the LLMs. Defaults to None.
                If a prompt is provided, it should have exactly one input variable.
        """
        if prompt is None:
            prompt = PromptTemplate(input_variables=["_input"], template="{_input}")
        chains = [LLMChain(llm=llm, prompt=prompt) for llm in llms]
        names = [str(llm) for llm in llms]
        return cls(chains, names=names)

    def compare(self, text: str) -> None:
        """Compare model outputs on an input text.

        If a prompt was provided when the laboratory was initialized, this text
        is fed into the prompt's single input variable. If no prompt was
        provided, the input text is used as the entire prompt.

        Args:
            text: input text to run all models on.
        """
        print(f"\033[1mInput:\033[0m\n{text}\n")
        for i, chain in enumerate(self.chains):
            if self.names is not None:
                name = self.names[i]
            else:
                name = str(chain)
            print_text(name, end="\n")
            output = chain.run(text)
            print_text(output, color=self.chain_colors[str(i)], end="\n\n")
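

if __name__ == "__main__":
    # Usage sketch, not part of the original module: a minimal example of how
    # ModelLaboratory might be driven, assuming the OpenAI LLM wrapper from
    # ``langchain.llms`` is available and an OpenAI API key is configured in
    # the environment. The prompt text, temperatures, and question are
    # illustrative only.
    from langchain.llms import OpenAI

    example_prompt = PromptTemplate(
        input_variables=["question"], template="Answer briefly: {question}"
    )
    lab = ModelLaboratory.from_llms(
        [OpenAI(temperature=0.0), OpenAI(temperature=0.9)], prompt=example_prompt
    )
    lab.compare("What is the capital of France?")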