"""Token counter function."""
import functools
import logging
from typing import Any, Callable, cast

from gpt_index.embeddings.base import BaseEmbedding
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
def llm_token_counter(method_name_str: str) -> Callable:
    """Decorator that logs LLM and embedding token usage of a method call.

    Use this as a decorator for methods in index/query classes that make calls
    to LLMs. At the moment, this decorator can only be used on class instance
    methods whose instance exposes both a `_llm_predictor` and an
    `_embed_model` attribute; each must have a monotonically increasing
    `total_tokens_used` counter and a writable `last_token_usage` field.

    Do not use this on abstract methods.

    For example, consider the class below:
    .. code-block:: python
        class GPTTreeIndexBuilder:
        ...
        @llm_token_counter("build_from_text")
        def build_from_text(self, documents: Sequence[BaseDocument]) -> IndexGraph:
            ...

    If you run `build_from_text()`, it will print the output in the form below:

    ```
    [build_from_text] Total token usage: <some-number> tokens
    ```

    Args:
        method_name_str: Label shown in the emitted log lines.

    Returns:
        A decorator wrapping the instance method; the wrapper returns the
        wrapped method's return value unchanged.

    Raises:
        ValueError: (from the wrapper) if the instance lacks `_llm_predictor`
            or `_embed_model`.
    """

    def wrap(f: Callable) -> Callable:
        # functools.wraps preserves the wrapped method's __name__/__doc__ so
        # decorated methods stay introspectable.
        @functools.wraps(f)
        def wrapped_llm_predict(_self: Any, *args: Any, **kwargs: Any) -> Any:
            # Fail loudly rather than logging bogus counts when the instance
            # does not carry the expected attributes.
            llm_predictor = getattr(_self, "_llm_predictor", None)
            if llm_predictor is None:
                raise ValueError(
                    "Cannot use llm_token_counter on an instance "
                    "without a _llm_predictor attribute."
                )
            embed_model = getattr(_self, "_embed_model", None)
            if embed_model is None:
                raise ValueError(
                    "Cannot use llm_token_counter on an instance "
                    "without a _embed_model attribute."
                )

            # Snapshot the cumulative counters; the post-call deltas are this
            # call's usage.
            start_token_ct = llm_predictor.total_tokens_used
            start_embed_token_ct = embed_model.total_tokens_used

            f_return_val = f(_self, *args, **kwargs)

            net_tokens = llm_predictor.total_tokens_used - start_token_ct
            llm_predictor.last_token_usage = net_tokens
            net_embed_tokens = embed_model.total_tokens_used - start_embed_token_ct
            embed_model.last_token_usage = net_embed_tokens

            # Lazy %-style args: messages are only rendered if INFO is enabled.
            logging.info(
                "> [%s] Total LLM token usage: %s tokens",
                method_name_str,
                net_tokens,
            )
            logging.info(
                "> [%s] Total embedding token usage: %s tokens",
                method_name_str,
                net_embed_tokens,
            )

            return f_return_val

        return wrapped_llm_predict

    return wrap