Spaces:
Runtime error
Runtime error
File size: 5,412 Bytes
35b22df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
"""Experiment with different indices, models, and more."""
from __future__ import annotations
import time
from typing import Any, Dict, List, Optional, Type, Union
import pandas as pd
from langchain.input import get_color_mapping, print_text
from gpt_index.indices.base import BaseGPTIndex
from gpt_index.indices.list.base import GPTListIndex
from gpt_index.indices.tree.base import GPTTreeIndex
from gpt_index.indices.vector_store import GPTSimpleVectorIndex
from gpt_index.readers.schema.base import Document
# Index classes that `Playground.from_docs` instantiates when the caller does
# not supply its own `index_classes`.
DEFAULT_INDEX_CLASSES = [
    GPTSimpleVectorIndex,
    GPTTreeIndex,
    GPTListIndex,
]
# Query modes a Playground uses when it is created without a `modes` argument.
DEFAULT_MODES = [
    "default",
    "summarize",
    "embedding",
    "retrieve",
    "recursive",
]
class Playground:
    """Experiment with indices, models, embeddings, modes, and more.

    A Playground holds a list of indices and a list of query modes.
    `compare` runs a single query against every index/mode pair that the
    index supports and collects the outputs side by side.
    """

    def __init__(
        self, indices: List[BaseGPTIndex], modes: Optional[List[str]] = None
    ) -> None:
        """Initialize with indices to experiment with.

        Args:
            indices: A list of BaseGPTIndex's to experiment with
            modes: A list of modes that specify which nodes are chosen
                from the index when a query is made. Defaults to a copy of
                DEFAULT_MODES when omitted. A full list of modes
                available to each index can be found here:
                https://gpt-index.readthedocs.io/en/latest/reference/query.html

        Raises:
            ValueError: If `indices` is empty or contains an object that is
                not a BaseGPTIndex, or if `modes` is empty.
        """
        self._validate_indices(indices)
        self._indices = indices
        if modes is None:
            # Copy so that mutating `playground.modes` can never alter the
            # module-level default shared by every other Playground.
            modes = list(DEFAULT_MODES)
        self._validate_modes(modes)
        self._modes = modes
        self.index_colors = self._build_index_colors(indices)

    @classmethod
    def from_docs(
        cls,
        documents: List[Document],
        index_classes: Optional[List[Type[BaseGPTIndex]]] = None,
        **kwargs: Any,
    ) -> Playground:
        """Initialize with Documents using the default list of indices.

        Args:
            documents: A List of Documents to experiment with.
            index_classes: Index classes to build from `documents`; each is
                called as `index_class(documents)`. Defaults to
                DEFAULT_INDEX_CLASSES when omitted.
            **kwargs: Forwarded unchanged to the Playground constructor
                (for example `modes`).

        Returns:
            A Playground holding one index per entry of `index_classes`.

        Raises:
            ValueError: If `documents` is empty.
        """
        if not documents:
            raise ValueError(
                "Playground must be initialized with a nonempty list of Documents."
            )
        if index_classes is None:
            index_classes = DEFAULT_INDEX_CLASSES
        indices = [index_class(documents) for index_class in index_classes]
        return cls(indices, **kwargs)

    @staticmethod
    def _build_index_colors(indices: List[BaseGPTIndex]) -> Dict[str, str]:
        """Map each index position (as a string) to a display color."""
        return get_color_mapping([str(i) for i in range(len(indices))])

    def _validate_indices(self, indices: List[BaseGPTIndex]) -> None:
        """Validate a list of indices.

        Raises:
            ValueError: If the list is empty or any element is not an
                instance of BaseGPTIndex.
        """
        if not indices:
            raise ValueError("Playground must have a non-empty list of indices.")
        for index in indices:
            if not isinstance(index, BaseGPTIndex):
                raise ValueError(
                    "Every index in Playground should be an instance of BaseGPTIndex."
                )

    @property
    def indices(self) -> List[BaseGPTIndex]:
        """Get Playground's indices."""
        return self._indices

    @indices.setter
    def indices(self, indices: List[BaseGPTIndex]) -> None:
        """Set Playground's indices and refresh their display colors."""
        self._validate_indices(indices)
        self._indices = indices
        # Colors are keyed by index position, so they must be rebuilt whenever
        # the list is replaced; otherwise `compare` would look up positions
        # that were never assigned a color.
        self.index_colors = self._build_index_colors(indices)

    def _validate_modes(self, modes: List[str]) -> None:
        """Validate a list of modes.

        Raises:
            ValueError: If the list is empty.
        """
        if not modes:
            raise ValueError(
                "Playground must have a nonzero number of modes. "
                "Initialize without the `modes` argument to use the default list."
            )

    @property
    def modes(self) -> List[str]:
        """Get Playground's modes."""
        return self._modes

    @modes.setter
    def modes(self, modes: List[str]) -> None:
        """Set Playground's modes."""
        self._validate_modes(modes)
        self._modes = modes

    def compare(
        self, query_text: str, to_pandas: Optional[bool] = True
    ) -> Union[pd.DataFrame, List[Dict[str, Any]]]:
        """Compare index outputs on an input query.

        Every index is queried once per mode it supports; modes missing from
        an index's query map are skipped. Progress and outputs are printed as
        the queries run.

        Args:
            query_text (str): Query to run all indices on.
            to_pandas (Optional[bool]): Return results in a pandas dataframe.
                True by default.

        Returns:
            The output of each index along with other data, such as the time it took to
            compute. Results are stored in a Pandas Dataframe or a list of Dicts.
        """
        print(f"\033[1mQuery:\033[0m\n{query_text}\n")
        print(f"Trying {len(self._indices) * len(self._modes)} combinations...\n\n")

        result = []
        for i, index in enumerate(self._indices):
            index_name = type(index).__name__
            # `.get` keeps the comparison running (uncolored) if the index
            # list was grown in place after the colors were assigned.
            color = self.index_colors.get(str(i))
            for mode in self._modes:
                if mode not in index.get_query_map():
                    continue
                print_text(f"\033[1m{index_name}\033[0m, mode = {mode}", end="\n")

                # Time only the query itself so that terminal printing does
                # not inflate the reported duration.
                start_time = time.perf_counter()
                output = index.query(query_text, mode=mode)
                duration = time.perf_counter() - start_time

                output_text = str(output)
                print_text(output_text, color=color, end="\n\n")
                result.append(
                    {
                        "Index": index_name,
                        "Mode": mode,
                        "Output": output_text,
                        "Duration": duration,
                        "LLM Tokens": index.llm_predictor.last_token_usage,
                        "Embedding Tokens": index.embed_model.last_token_usage,
                    }
                )
        print(f"\nRan {len(result)} combinations in total.")

        if to_pandas:
            return pd.DataFrame(result)
        return result
|