binhnase04854's picture
first deploy
b699122
"""Token counter function."""
import asyncio
import logging
from contextlib import contextmanager
from typing import Any, Callable
from gpt_index.indices.service_context import ServiceContext
logger = logging.getLogger(__name__)
def llm_token_counter(method_name_str: str) -> Callable:
"""
Use this as a decorator for methods in index/query classes that make calls to LLMs.
At the moment, this decorator can only be used on class instance methods with a
`_llm_predictor` attribute.
Do not use this on abstract methods.
For example, consider the class below:
.. code-block:: python
class GPTTreeIndexBuilder:
...
@llm_token_counter("build_from_text")
def build_from_text(self, documents: Sequence[BaseDocument]) -> IndexGraph:
...
If you run `build_from_text()`, it will print the output in the form below:
```
[build_from_text] Total token usage: <some-number> tokens
```
"""
def wrap(f: Callable) -> Callable:
@contextmanager
def wrapper_logic(_self: Any) -> Any:
service_context = getattr(_self, "_service_context", None)
if not isinstance(service_context, ServiceContext):
raise ValueError(
"Cannot use llm_token_counter on an instance "
"without a service context."
)
llm_predictor = service_context.llm_predictor
embed_model = service_context.embed_model
start_token_ct = llm_predictor.total_tokens_used
start_embed_token_ct = embed_model.total_tokens_used
yield
net_tokens = llm_predictor.total_tokens_used - start_token_ct
llm_predictor.last_token_usage = net_tokens
net_embed_tokens = embed_model.total_tokens_used - start_embed_token_ct
embed_model.last_token_usage = net_embed_tokens
# print outputs
logger.info(
f"> [{method_name_str}] Total LLM token usage: {net_tokens} tokens"
)
logger.info(
f"> [{method_name_str}] Total embedding token usage: "
f"{net_embed_tokens} tokens"
)
async def wrapped_async_llm_predict(
_self: Any, *args: Any, **kwargs: Any
) -> Any:
with wrapper_logic(_self):
f_return_val = await f(_self, *args, **kwargs)
return f_return_val
def wrapped_llm_predict(_self: Any, *args: Any, **kwargs: Any) -> Any:
with wrapper_logic(_self):
f_return_val = f(_self, *args, **kwargs)
return f_return_val
if asyncio.iscoroutinefunction(f):
return wrapped_async_llm_predict
else:
return wrapped_llm_predict
return wrap