import requests
import base64
import uuid
import json
import time
from typing import Dict, Optional, Any
from dotenv import load_dotenv
import os

from src.prompts import giga_system_prompt_i2v

# Load environment variables from .env file
load_dotenv()

# Credentials for the token endpoint, read from the environment / .env file.
AUTH_TOKEN = os.getenv("AUTH_TOKEN")
COOKIE = os.getenv("COOKIE")

# An epoch timestamp above this value cannot be in seconds (it would be around
# the year 5138), so check_auth_token treats such values as milliseconds.
_MS_EXPIRY_THRESHOLD = 10 ** 11


def get_auth_token(timeout: float = 2) -> Dict[str, Any]:
    """
    Get authentication token.

    Sends a form-encoded POST to the token endpoint, authenticating with
    ``Basic AUTH_TOKEN`` and the ``COOKIE`` value from the environment.

    Args:
        timeout (float): Timeout duration in seconds.

    Returns:
        Dict[str, Any]: Dictionary with ``access_token`` (the ``tok`` field of
        the response) and ``expires_at`` (the ``exp`` field of the response).

    Raises:
        requests.HTTPError: If the token endpoint answers with an error status.
        requests.RequestException: On connection failure or timeout.
    """
    url = "https://beta.saluteai.sberdevices.ru/v1/token"
    payload = 'scope=GIGACHAT_API_CORP'
    headers = {
        'Content-Type': 'application/x-www-form-urlencoded',
        'Accept': 'application/json',
        'RqUID': str(uuid.uuid4()),  # fresh request id for every token call
        'Cookie': COOKIE,
        'Authorization': f'Basic {AUTH_TOKEN}',
    }
    response = requests.post(url, headers=headers, data=payload, timeout=timeout)
    # Report the HTTP status instead of failing later with KeyError: 'tok'.
    response.raise_for_status()
    response_dict = response.json()
    return {
        'access_token': response_dict['tok'],
        'expires_at': response_dict['exp'],
    }


def check_auth_token(token_data: Dict[str, Any], margin: float = 5) -> bool:
    """
    Check if the authentication token is still valid.

    Args:
        token_data (Dict[str, Any]): Dictionary with an ``expires_at`` epoch
            timestamp, as produced by ``get_auth_token``.
        margin (float): Remaining lifetime, in seconds, required for the token
            to count as valid.

    Returns:
        bool: True if the token expires more than ``margin`` seconds from now.
    """
    expires_at = token_data['expires_at']
    # NOTE(review): the unit of the endpoint's `exp` field is not visible here.
    # A millisecond timestamp compared with time.time() (seconds) would never
    # look expired, so values that can only be milliseconds are scaled down.
    # Second-based timestamps are unaffected.
    if expires_at > _MS_EXPIRY_THRESHOLD:
        expires_at = expires_at / 1000
    return expires_at - time.time() > margin


# Module-level token cache, (re)filled by get_response_json when missing or
# about to expire.
token_data: Optional[Dict[str, Any]] = None


def get_response_json(
    prompt: str,
    model: str,
    timeout: int = 5,
    n: int = 1,
    fuse_key_word: Optional[str] = None,
    use_giga_censor: bool = False,
    max_tokens: int = 128,
    max_attempts: int = 5,
    retry_delay: float = 5,
) -> Dict[str, Any]:
    """
    Send a chat-completion request to the API and return the decoded JSON.

    The conversation is the system prompt ``giga_system_prompt_i2v`` followed
    by ``prompt`` as the user message. The cached auth token is refreshed
    first if it is missing or about to expire.

    Args:
        prompt (str): The input prompt (sent as the user message).
        model (str): The model to be used for generation.
        timeout (int): Per-request timeout in seconds.
        n (int): Number of responses.
        fuse_key_word (Optional[str]): Currently unused; kept for backward
            compatibility (it used to be prepended to the prompt).
        use_giga_censor (bool): Whether to use profanity filtering.
        max_tokens (int): Maximum number of tokens in the response.
        max_attempts (int): How many times to try the request.
        retry_delay (float): Seconds to wait between failed attempts.

    Returns:
        Dict[str, Any]: The decoded JSON body of the first response that could
        be parsed. This may itself describe an API error.

    Raises:
        RuntimeError: If no attempt produced a parseable response (chained
            from the last underlying error).
    """
    global token_data
    url = "https://beta.saluteai.sberdevices.ru/v1/chat/completions"
    payload = json.dumps({
        "model": model,
        "messages": [
            {"role": "system", "content": giga_system_prompt_i2v},
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.87,
        "top_p": 0.47,
        "n": n,
        "stream": False,
        "max_tokens": max_tokens,
        "repetition_penalty": 1.07,
        "profanity_check": use_giga_censor,
    })

    if token_data is None or not check_auth_token(token_data):
        token_data = get_auth_token()

    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'Authorization': f'Bearer {token_data["access_token"]}',
    }

    last_error: Optional[Exception] = None
    for attempt_num in range(max_attempts):
        try:
            response = requests.post(url, headers=headers, data=payload, timeout=timeout)
            return response.json()
        except (requests.RequestException, ValueError) as err:
            # ValueError covers a body that is not valid JSON.
            last_error = err
            if attempt_num + 1 < max_attempts:  # no point sleeping after the last try
                time.sleep(retry_delay)
    raise RuntimeError(
        f"Chat completion request failed after {max_attempts} attempt(s)"
    ) from last_error


def giga_generate(
    prompt: str,
    model_version: str = "GigaChat-Max",
    max_tokens: int = 128,
    max_attempts: int = 5,
) -> str:
    """
    Generate text using the GigaChat model.

    Args:
        prompt (str): The input prompt.
        model_version (str): The version of the model to use.
        max_tokens (int): Maximum number of tokens in the response.
        max_attempts (int): Request attempts passed to ``get_response_json``.

    Returns:
        str: The content of the first choice. Returns ``'Censored Text'`` if
        its finish reason is ``'blacklist'``, and ``prompt`` unchanged if the
        response does not have the expected ``choices`` structure.

    Raises:
        RuntimeError: Propagated from ``get_response_json`` when no attempt
            yields a response.
    """
    response_dict = get_response_json(
        prompt,
        model_version,
        use_giga_censor=False,
        max_tokens=max_tokens,
        max_attempts=max_attempts,
    )
    try:
        choice = response_dict['choices'][0]
        if choice['finish_reason'] == 'blacklist':
            print('GigaCensor triggered!')
            return 'Censored Text'
        return choice['message']['content']
    except (KeyError, IndexError, TypeError):
        # Error payloads or malformed responses lack 'choices'; fall back to the prompt.
        return prompt