File size: 2,623 Bytes
35b22df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Token counter function."""

import functools
import logging
from typing import Any, Callable, cast

from gpt_index.embeddings.base import BaseEmbedding
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor


def llm_token_counter(method_name_str: str) -> Callable:
    """
    Use this as a decorator for methods in index/query classes that make calls to LLMs.

    At the moment, this decorator can only be used on class instance methods whose
    instance has both an ``_llm_predictor`` attribute (an ``LLMPredictor``) and an
    ``_embed_model`` attribute (a ``BaseEmbedding``). A ``ValueError`` is raised at
    call time if either attribute is missing or is ``None``.

    Do not use this on abstract methods.

    After the wrapped method returns, the net number of tokens it consumed is
    stored on ``_llm_predictor.last_token_usage`` and
    ``_embed_model.last_token_usage`` and logged at INFO level. If the wrapped
    method raises, neither ``last_token_usage`` is updated and nothing is logged.

    For example, consider the class below:

    .. code-block:: python

        class GPTTreeIndexBuilder:
            ...
            @llm_token_counter("build_from_text")
            def build_from_text(self, documents: Sequence[BaseDocument]) -> IndexGraph:
                ...

    If you run ``build_from_text()``, it will log output in the form below:

    .. code-block:: text

        > [build_from_text] Total LLM token usage: <some-number> tokens
        > [build_from_text] Total embedding token usage: <some-number> tokens

    Args:
        method_name_str: Label shown in square brackets in the log lines.

    Returns:
        A decorator that wraps an instance method with token accounting.
    """

    def wrap(f: Callable) -> Callable:
        # Preserve the decorated method's __name__, __doc__, __wrapped__, etc.,
        # so introspection and generated docs still see the real method.
        @functools.wraps(f)
        def wrapped_llm_predict(_self: Any, *args: Any, **kwargs: Any) -> Any:
            llm_predictor = getattr(_self, "_llm_predictor", None)
            if llm_predictor is None:
                raise ValueError(
                    "Cannot use llm_token_counter on an instance "
                    "without a _llm_predictor attribute."
                )
            llm_predictor = cast(LLMPredictor, llm_predictor)

            embed_model = getattr(_self, "_embed_model", None)
            if embed_model is None:
                raise ValueError(
                    "Cannot use llm_token_counter on an instance "
                    "without a _embed_model attribute."
                )
            embed_model = cast(BaseEmbedding, embed_model)

            # Snapshot the running totals so only this call's usage is measured.
            start_token_ct = llm_predictor.total_tokens_used
            start_embed_token_ct = embed_model.total_tokens_used

            f_return_val = f(_self, *args, **kwargs)

            net_tokens = llm_predictor.total_tokens_used - start_token_ct
            llm_predictor.last_token_usage = net_tokens
            net_embed_tokens = embed_model.total_tokens_used - start_embed_token_ct
            embed_model.last_token_usage = net_embed_tokens

            # Lazy %-style args: the message is only formatted if INFO is enabled.
            logging.info(
                "> [%s] Total LLM token usage: %s tokens",
                method_name_str,
                net_tokens,
            )
            logging.info(
                "> [%s] Total embedding token usage: %s tokens",
                method_name_str,
                net_embed_tokens,
            )

            return f_return_val

        return wrapped_llm_predict

    return wrap