File size: 4,674 Bytes
469eae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
Transformation logic from Cohere's `/v1/rerank` format to Jina AI's `/v1/rerank` format.

Why a separate file? To make it easy to see how the transformation works.

Docs - https://jina.ai/reranker
"""

import uuid
from typing import Any, Dict, List, Optional, Tuple, Union

from httpx import URL, Response

from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.types.rerank import (
    OptionalRerankParams,
    RerankBilledUnits,
    RerankResponse,
    RerankResponseMeta,
    RerankTokens,
)
from litellm.types.utils import ModelInfo


class JinaAIRerankConfig(BaseRerankConfig):
    """Translate Cohere-style rerank calls to and from Jina AI's `/v1/rerank` API."""

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        """Return the Cohere rerank parameter names that are forwarded to Jina AI.

        Args:
            model: Model name; not consulted, the list is the same for every model.
        """
        return [
            "query",
            "top_n",
            "documents",
            "return_documents",
        ]

    def map_cohere_rerank_params(
        self,
        non_default_params: dict,
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> OptionalRerankParams:
        """Keep only the supported entries of ``non_default_params``.

        Only ``non_default_params`` and ``model`` are read. The remaining
        arguments are accepted but not used by this method; unsupported keys
        are dropped silently regardless of ``drop_params``.

        Returns:
            An ``OptionalRerankParams`` built from the supported key/value pairs.
        """
        supported_params = self.get_supported_cohere_rerank_params(model)
        optional_params = {
            key: value
            for key, value in non_default_params.items()
            if key in supported_params
        }
        return OptionalRerankParams(**optional_params)

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """Build the rerank endpoint URL.

        Args:
            api_base: Custom base URL, or ``None`` to use Jina AI's public API.
            model: Model name; not used when building the URL.

        Returns:
            ``https://api.jina.ai/v1/rerank`` when ``api_base`` is ``None``;
            otherwise ``api_base`` with its path replaced by ``/v1/rerank``.
        """
        base_path = "/v1/rerank"

        if api_base is None:
            return "https://api.jina.ai/v1/rerank"

        # Whatever path the caller supplied is overwritten, so both
        # "https://host" and "https://host/v1/rerank" resolve to the same URL.
        base = URL(api_base)
        return str(base.copy_with(path=base_path))

    def transform_rerank_request(
        self, model: str, optional_rerank_params: OptionalRerankParams, headers: Dict
    ) -> Dict:
        """Build the request body: the model name plus the mapped rerank params.

        ``headers`` is accepted but not used here.
        """
        return {"model": model, **optional_rerank_params}

    def transform_rerank_response(
        self,
        model: str,
        raw_response: Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: Dict = {},
        optional_params: Dict = {},
        litellm_params: Dict = {},
    ) -> RerankResponse:
        """Convert Jina AI's HTTP response into a ``RerankResponse``.

        Only ``raw_response`` and ``logging_obj`` are read. The ``{}`` defaults
        are never mutated here and are kept so the signature stays unchanged.

        Returns:
            A ``RerankResponse`` carrying the response ``id`` (or a fresh UUID4
            when the response has none), the ``results`` list, and a meta
            object built from the ``usage`` payload.

        Raises:
            Exception: If the HTTP status code is not 200; the message is the
                raw response text.
            ValueError: If the JSON body has no ``results`` entry.
        """
        if raw_response.status_code != 200:
            raise Exception(raw_response.text)

        logging_obj.post_call(original_response=raw_response.text)

        _json_response = raw_response.json()

        # `usage` may be absent *or* explicitly null. `.get("usage", {})` only
        # covers the former, and unpacking None raises TypeError, so normalise
        # both cases to an empty mapping.
        _usage = _json_response.get("usage") or {}
        _billed_units = RerankBilledUnits(**_usage)
        _tokens = RerankTokens(**_usage)
        rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)

        _results: Optional[List[dict]] = _json_response.get("results")

        if _results is None:
            raise ValueError(f"No results found in the response={_json_response}")

        return RerankResponse(
            id=_json_response.get("id") or str(uuid.uuid4()),
            results=_results,  # type: ignore
            meta=rerank_meta,
        )

    def validate_environment(
        self, headers: Dict, model: str, api_key: Optional[str] = None
    ) -> Dict:
        """Return the HTTP headers for a Jina AI request.

        The incoming ``headers`` and ``model`` are not read; a fresh header
        dict is always returned.

        Raises:
            ValueError: If ``api_key`` is ``None``.
        """
        if api_key is None:
            raise ValueError(
                "api_key is required. Set via `api_key` parameter or `JINA_API_KEY` environment variable."
            )
        return {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {api_key}",
        }

    def calculate_rerank_cost(
        self,
        model: str,
        custom_llm_provider: Optional[str] = None,
        billed_units: Optional[RerankBilledUnits] = None,
        model_info: Optional[ModelInfo] = None,
    ) -> Tuple[float, float]:
        """
        Jina AI reranker is priced at $0.000000018 per token.

        The rate actually applied is read from
        ``model_info["input_cost_per_token"]``, not hard-coded.

        Returns:
            ``(input_cost, 0.0)``, where ``input_cost`` is the per-token rate
            times ``billed_units["total_tokens"]``. ``(0.0, 0.0)`` is returned
            when the rate, the billed units, or the token count is missing.
            The second element is always ``0.0`` (presumably the output cost;
            confirm against the base class).
        """
        if (
            model_info is None
            or "input_cost_per_token" not in model_info
            or model_info["input_cost_per_token"] is None
            or billed_units is None
        ):
            return 0.0, 0.0

        total_tokens = billed_units.get("total_tokens")
        if total_tokens is None:
            return 0.0, 0.0

        input_cost = model_info["input_cost_per_token"] * total_tokens
        return input_cost, 0.0