File size: 4,743 Bytes
2adb577
 
 
 
 
 
 
 
4535fc2
2adb577
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc516fa
2adb577
 
 
 
 
 
27539d2
cc516fa
2adb577
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc516fa
 
 
 
 
 
1bfbe04
cc516fa
 
 
 
 
 
2adb577
 
 
 
cc516fa
 
2adb577
 
 
 
 
 
 
 
 
 
 
 
cc516fa
2adb577
 
 
27539d2
cc516fa
2adb577
cc516fa
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import requests
import base64
import uuid
import json
import time
from typing import Dict, Optional, Any
from dotenv import load_dotenv
import os
from src.prompts import giga_system_prompt_i2v

# Read the .env file first so that the lookups below can see its values.
load_dotenv()

# Credentials used when requesting an access token; each one is None when the
# corresponding environment variable is not set.
AUTH_TOKEN = os.environ.get("AUTH_TOKEN")
COOKIE = os.environ.get("COOKIE")

# Debugging aid, left disabled (it would write the credentials to stdout):
# print(f"AUTH_TOKEN: {AUTH_TOKEN}")
# print(f"COOKIE: {COOKIE}")

def get_auth_token(timeout: float = 2) -> Dict[str, Any]:
    """
    Request a fresh access token from the token endpoint.

    Args:
        timeout (float): Timeout passed to the HTTP request, in seconds.

    Returns:
        Dict[str, Any]: Dictionary with two keys: 'access_token' (taken from
            the 'tok' field of the response) and 'expires_at' (taken from the
            'exp' field of the response).

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        requests.RequestException: If the request itself fails, for example
            on a connection error or a timeout.
        KeyError: If the response body lacks the 'tok' or 'exp' field.
    """
    url = "https://beta.saluteai.sberdevices.ru/v1/token"
    payload = 'scope=GIGACHAT_API_CORP'
    headers = {
        'Content-Type': 'application/x-www-form-urlencoded',
        'Accept': 'application/json',
        # A new request id is generated for every call.
        'RqUID': str(uuid.uuid4()),
        'Cookie': COOKIE,
        'Authorization': f'Basic {AUTH_TOKEN}'
    }
    response = requests.post(url, headers=headers, data=payload, timeout=timeout)
    # Fail on an HTTP error status right here, instead of with an unrelated
    # KeyError or JSON decoding error when the error body is parsed below.
    response.raise_for_status()
    response_dict = response.json()
    return {
        'access_token': response_dict['tok'],
        'expires_at': response_dict['exp']
    }

def check_auth_token(token_data: Dict[str, Any]) -> bool:
    """
    Tell whether a token still has enough lifetime left to be used.

    Args:
        token_data (Dict[str, Any]): Token dictionary; only its 'expires_at'
            entry is read.

    Returns:
        bool: True when 'expires_at' is more than 5 above the current
            time.time() value, False otherwise.
    """
    # NOTE(review): the comparison is against time.time(), so 'expires_at' is
    # treated as a Unix timestamp in seconds -- confirm against the token API.
    expiry = token_data['expires_at']
    remaining = expiry - time.time()
    return remaining > 5

# Module-level cache for the most recently fetched token. It stays None until
# get_response_json obtains a token, and is replaced there whenever the cached
# one no longer passes check_auth_token.
token_data: Optional[Dict[str, Any]] = None

def get_response_json(
    prompt: str,
    model: str,
    timeout: int = 120,
    n: int = 1,
    fuse_key_word: Optional[str] = None,
    use_giga_censor: bool = False,
    max_tokens: int = 128,
    max_attempts: int = 5,
) -> Dict[str, Any]:
    """
    Send a chat-completion request to the API and return the decoded JSON body.

    The request consists of the `giga_system_prompt_i2v` system message
    followed by `prompt` as the user message. A cached access token is reused
    while it passes `check_auth_token`; otherwise a new one is fetched once,
    before the first attempt.

    Args:
        prompt (str): The input prompt, sent as the user message.
        model (str): The model to be used for generation.
        timeout (int): Timeout of each HTTP request, in seconds.
        n (int): Number of responses.
        fuse_key_word (Optional[str]): Currently not used: the prompt is sent
            unchanged. Kept in the signature for backward compatibility.
        use_giga_censor (bool): Whether to use profanity filtering.
        max_tokens (int): Maximum number of tokens in the response.
        max_attempts (int): Maximum number of times the request is sent
            before giving up. Must be at least 1.

    Returns:
        Dict[str, Any]: The JSON body of the first response that could be
            decoded. The HTTP status is not checked, so an error body returned
            by the API is passed through as-is.

    Raises:
        ValueError: If `max_attempts` is smaller than 1.
        RuntimeError: If every attempt failed with a request error or an
            undecodable body; the last such error is attached as the cause.
    """
    global token_data

    if max_attempts < 1:
        raise ValueError(f"max_attempts must be at least 1, got {max_attempts}")

    url = "https://beta.saluteai.sberdevices.ru/v1/chat/completions"
    payload = json.dumps({
        "model": model,
        "messages": [
            {
                "role": "system",
                "content": giga_system_prompt_i2v
            },
            {
                "role": "user",
                "content": prompt
            }
        ],
        "temperature": 0.87,
        "top_p": 0.47,
        "n": n,
        "stream": False,
        "max_tokens": max_tokens,
        "repetition_penalty": 1.07,
        "profanity_check": use_giga_censor
    })

    if token_data is None or not check_auth_token(token_data):
        token_data = get_auth_token()

    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'Authorization': f'Bearer {token_data["access_token"]}'
    }

    last_error: Optional[Exception] = None
    for attempt_num in range(max_attempts):
        if attempt_num:
            # Pause between attempts only; there is nothing to wait for after
            # the final failure.
            time.sleep(5)
        try:
            response = requests.post(url, headers=headers, data=payload, timeout=timeout)
            return response.json()
        except (requests.RequestException, ValueError) as error:
            # RequestException covers transport failures; ValueError covers a
            # body that is not valid JSON. KeyboardInterrupt and SystemExit
            # are deliberately not intercepted.
            last_error = error

    raise RuntimeError(
        f"Request to {url} failed after {max_attempts} attempts"
    ) from last_error
   
def giga_generate(
    prompt: str,
    model_version: str = "GigaChat-Max",
    max_tokens: int = 128,
    max_attempts: int = 5
) -> str:
    """
    Generate text using the GigaChat model.

    Args:
        prompt (str): The input prompt.
        model_version (str): The version of the model to use.
        max_tokens (int): Maximum number of tokens in the response.
        max_attempts (int): Maximum number of request attempts, passed through
            to `get_response_json`.

    Returns:
        str: The message content of the first choice in the response. If that
            choice has finish_reason 'blacklist', the string 'Censored Text'
            is returned instead. If the response does not have the expected
            structure, `prompt` is returned unchanged.

    Raises:
        Any exception raised by `get_response_json` propagates to the caller.
    """
    response_dict = get_response_json(
        prompt,
        model_version,
        use_giga_censor=False,
        max_tokens=max_tokens,
        max_attempts=max_attempts,
    )
    try:
        first_choice = response_dict['choices'][0]
        if first_choice['finish_reason'] != 'blacklist':
            return first_choice['message']['content']
    except (KeyError, IndexError, TypeError):
        # The body lacks the expected fields (for example it is an API error
        # object), so fall back to the unmodified prompt.
        return prompt

    # Reached only when the first choice was stopped with 'blacklist'.
    print('GigaCensor triggered!')
    return 'Censored Text'