File size: 5,412 Bytes
35b22df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""Experiment with different indices, models, and more."""
from __future__ import annotations

import time
from typing import Any, Dict, List, Optional, Type, Union

import pandas as pd
from langchain.input import get_color_mapping, print_text

from gpt_index.indices.base import BaseGPTIndex
from gpt_index.indices.list.base import GPTListIndex
from gpt_index.indices.tree.base import GPTTreeIndex
from gpt_index.indices.vector_store import GPTSimpleVectorIndex
from gpt_index.readers.schema.base import Document

# Index classes that `Playground.from_docs` builds when the caller does not pass
# its own `index_classes`.
DEFAULT_INDEX_CLASSES = [GPTSimpleVectorIndex, GPTTreeIndex, GPTListIndex]
# Query modes a `Playground` tries when the caller does not pass `modes`.
# `Playground.compare` skips any mode that is missing from an index's
# `get_query_map()`, so not every mode here runs against every index.
DEFAULT_MODES = ["default", "summarize", "embedding", "retrieve", "recursive"]


class Playground:
    """Experiment with indices, models, embeddings, modes, and more.

    A Playground holds a list of already-built indices and a list of query
    modes. `compare` runs one query against every (index, mode) pair that the
    index supports and collects the outputs for side-by-side inspection.
    """

    def __init__(self, indices: List[BaseGPTIndex], modes: List[str] = DEFAULT_MODES):
        """Initialize with indices to experiment with.

        Args:
            indices: A list of BaseGPTIndex's to experiment with
            modes: A list of modes that specify which nodes are chosen
                from the index when a query is made. A full list of modes
                available to each index can be found here:
                https://gpt-index.readthedocs.io/en/latest/reference/query.html

        Raises:
            ValueError: If `indices` is empty or contains an object that is not
                a BaseGPTIndex, or if `modes` is empty.
        """
        self._validate_indices(indices)
        self._indices = indices
        self._validate_modes(modes)
        # Store a copy: otherwise `playground.modes.append(...)` would mutate the
        # shared module-level default list and leak into every later Playground.
        self._modes = list(modes)
        self._refresh_index_colors()

    @classmethod
    def from_docs(
        cls,
        documents: List[Document],
        index_classes: List[Type[BaseGPTIndex]] = DEFAULT_INDEX_CLASSES,
        **kwargs: Any,
    ) -> Playground:
        """Initialize with Documents using the default list of indices.

        Args:
            documents: A List of Documents to experiment with.
            index_classes: Index classes to build; each one is called with
                `documents` as its only argument.
            **kwargs: Forwarded unchanged to the Playground constructor
                (for example `modes`).

        Returns:
            A Playground wrapping one freshly built index per index class.

        Raises:
            ValueError: If `documents` is empty.
        """
        if not documents:
            raise ValueError(
                "Playground must be initialized with a nonempty list of Documents."
            )

        indices = [index_class(documents) for index_class in index_classes]
        return cls(indices, **kwargs)

    def _refresh_index_colors(self) -> None:
        """Rebuild the position -> color mapping for the current indices."""
        # Keys are the stringified positions of the indices, which is what
        # `compare` uses to look up the color for each index's output.
        index_range = [str(i) for i in range(len(self._indices))]
        self.index_colors = get_color_mapping(index_range)

    def _validate_indices(self, indices: List[BaseGPTIndex]) -> None:
        """Validate a list of indices.

        Raises:
            ValueError: If the list is empty or any element is not an instance
                of BaseGPTIndex.
        """
        if len(indices) == 0:
            raise ValueError("Playground must have a non-empty list of indices.")
        for index in indices:
            if not isinstance(index, BaseGPTIndex):
                raise ValueError(
                    "Every index in Playground should be an instance of BaseGPTIndex."
                )

    @property
    def indices(self) -> List[BaseGPTIndex]:
        """Get Playground's indices."""
        return self._indices

    @indices.setter
    def indices(self, indices: List[BaseGPTIndex]) -> None:
        """Set Playground's indices."""
        self._validate_indices(indices)
        self._indices = indices
        # The color mapping is keyed by index position, so it must follow the
        # new list; a stale mapping raises KeyError in `compare` when the new
        # list is longer than the old one.
        self._refresh_index_colors()

    def _validate_modes(self, modes: List[str]) -> None:
        """Validate a list of modes.

        Raises:
            ValueError: If the list is empty.
        """
        if len(modes) == 0:
            raise ValueError(
                "Playground must have a nonzero number of modes. "
                "Initialize without the `modes` argument to use the default list."
            )

    @property
    def modes(self) -> List[str]:
        """Get Playground's modes."""
        return self._modes

    @modes.setter
    def modes(self, modes: List[str]) -> None:
        """Set Playground's modes."""
        self._validate_modes(modes)
        # Copy for the same reason as in __init__: never alias a caller-owned
        # or module-level list.
        self._modes = list(modes)

    def compare(
        self, query_text: str, to_pandas: Optional[bool] = True
    ) -> Union[pd.DataFrame, List[Dict[str, Any]]]:
        """Compare index outputs on an input query.

        Every index is queried once per mode. Modes that are not present in an
        index's `get_query_map()` are skipped, so the number of combinations
        actually run can be lower than the "Trying ..." count printed up front.

        Args:
            query_text (str): Query to run all indices on.
            to_pandas (Optional[bool]): Return results in a pandas dataframe.
                True by default.

        Returns:
            The output of each index along with other data, such as the time it took to
            compute. Results are stored in a Pandas Dataframe or a list of Dicts.
            Each record has the keys "Index", "Mode", "Output", "Duration"
            (seconds spent in the query call), "LLM Tokens" and
            "Embedding Tokens".
        """
        print(f"\033[1mQuery:\033[0m\n{query_text}\n")
        print(f"Trying {len(self._indices) * len(self._modes)} combinations...\n\n")
        result = []
        for i, index in enumerate(self._indices):
            # Both values depend only on the index, so compute them once per
            # index rather than once per mode.
            index_name = type(index).__name__
            supported_modes = index.get_query_map()
            for mode in self._modes:
                if mode not in supported_modes:
                    continue

                print_text(f"\033[1m{index_name}\033[0m, mode = {mode}", end="\n")

                # Time only the query itself, with a monotonic clock, so the
                # figure is not skewed by console output or system clock
                # adjustments.
                start_time = time.perf_counter()
                output = index.query(query_text, mode=mode)
                duration = time.perf_counter() - start_time

                print_text(str(output), color=self.index_colors[str(i)], end="\n\n")

                result.append(
                    {
                        "Index": index_name,
                        "Mode": mode,
                        "Output": str(output),
                        "Duration": duration,
                        "LLM Tokens": index.llm_predictor.last_token_usage,
                        "Embedding Tokens": index.embed_model.last_token_usage,
                    }
                )
        print(f"\nRan {len(result)} combinations in total.")

        if to_pandas:
            return pd.DataFrame(result)
        return result