"""
Handler file for calls to Azure OpenAI's o1/o3 family of models

Written separately to handle faking streaming for o1 and o3 models.
"""

from typing import Any, Callable, Optional, Union

import httpx

from litellm.types.utils import ModelResponse

from ...openai.openai import OpenAIChatCompletion
from ..common_utils import BaseAzureLLM


class AzureOpenAIO1ChatCompletion(BaseAzureLLM, OpenAIChatCompletion):
    def completion(
        self,
        model_response: ModelResponse,
        timeout: Union[float, httpx.Timeout],
        optional_params: dict,
        litellm_params: dict,
        logging_obj: Any,
        model: Optional[str] = None,
        messages: Optional[list] = None,
        print_verbose: Optional[Callable] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        dynamic_params: Optional[bool] = None,
        azure_ad_token: Optional[str] = None,
        acompletion: bool = False,
        logger_fn=None,
        headers: Optional[dict] = None,
        custom_prompt_dict: dict = {},
        client=None,
        organization: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        drop_params: Optional[bool] = None,
    ):
        # Build (or reuse) the Azure OpenAI client for this request; an async
        # client is returned when `acompletion` is True.
        client = self.get_azure_openai_client(
            litellm_params=litellm_params,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=acompletion,
        )
        # Delegate the actual request to the shared OpenAI chat-completion
        # implementation, passing through the pre-built Azure client.
        return super().completion(
            model_response=model_response,
            timeout=timeout,
            optional_params=optional_params,
            litellm_params=litellm_params,
            logging_obj=logging_obj,
            model=model,
            messages=messages,
            print_verbose=print_verbose,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            dynamic_params=dynamic_params,
            azure_ad_token=azure_ad_token,
            acompletion=acompletion,
            logger_fn=logger_fn,
            headers=headers,
            custom_prompt_dict=custom_prompt_dict,
            client=client,
            organization=organization,
            custom_llm_provider=custom_llm_provider,
            drop_params=drop_params,
        )
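

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the handler): callers normally
# do not instantiate this class directly; litellm.completion() routes Azure
# o1/o3 requests to it. The deployment name, endpoint, API version, and key
# below are placeholders.
#
#   import litellm
#
#   response = litellm.completion(
#       model="azure/o1-mini",                    # "azure/<deployment-name>"
#       messages=[{"role": "user", "content": "Hello"}],
#       api_base="https://<resource>.openai.azure.com",  # placeholder
#       api_version="2024-12-01-preview",                # placeholder
#       api_key="<azure-api-key>",                       # placeholder
#       stream=True,  # streaming is faked for these models (see module docstring)
#   )
# ---------------------------------------------------------------------------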