ai-forever's picture
add max_attempts
cc516fa verified
raw
history blame
4.71 kB
import requests
import base64
import uuid
import json
import time
from typing import Dict, Optional, Any
from dotenv import load_dotenv
import os
from src.prompts import giga_system_prompt_i2v
# Load environment variables from a local .env file (if present) into os.environ.
load_dotenv()
# Credential sent as HTTP Basic auth to the token endpoint (see get_auth_token); None if unset.
AUTH_TOKEN = os.getenv("AUTH_TOKEN")
# Value sent as the Cookie header to the token endpoint; None if unset.
COOKIE = os.getenv("COOKIE")
def get_auth_token(timeout: float = 2) -> Dict[str, Any]:
    """
    Request a new access token from the token endpoint.

    Credentials come from the module-level AUTH_TOKEN (sent as HTTP Basic auth)
    and COOKIE values.

    Args:
        timeout (float): Request timeout in seconds.

    Returns:
        Dict[str, Any]: ``{'access_token': ..., 'expires_at': ...}``, taken from
        the response's ``tok`` and ``exp`` fields respectively.

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
        KeyError: If the JSON body lacks ``tok`` or ``exp``.
    """
    url = "https://beta.saluteai.sberdevices.ru/v1/token"
    payload = 'scope=GIGACHAT_API_CORP'
    headers = {
        'Content-Type': 'application/x-www-form-urlencoded',
        'Accept': 'application/json',
        'RqUID': str(uuid.uuid4()),  # fresh random UUID for every call
        'Cookie': COOKIE,
        'Authorization': f'Basic {AUTH_TOKEN}'
    }
    response = requests.post(url, headers=headers, data=payload, timeout=timeout)
    # Surface the HTTP status (e.g. bad credentials) instead of a confusing
    # KeyError on 'tok' when the body is an error payload.
    response.raise_for_status()
    response_dict = response.json()
    return {
        'access_token': response_dict['tok'],
        'expires_at': response_dict['exp']
    }
def check_auth_token(token_data: Dict[str, Any]) -> bool:
    """
    Check if the authentication token is still usable.

    A token is considered valid while more than 5 seconds remain before its
    ``expires_at`` Unix timestamp.

    Args:
        token_data (Dict[str, Any]): Dictionary with an ``expires_at`` key, as
            returned by get_auth_token.

    Returns:
        bool: True if the token is valid, False otherwise.
    """
    expires_at = token_data['expires_at']
    # NOTE(review): the unit of the endpoint's `exp` field is not visible here.
    # A value above 1e11 cannot be a seconds timestamp (that would be beyond
    # year 5000), so treat it as milliseconds; otherwise a millisecond expiry
    # would make every token look valid forever and it would never be refreshed.
    if expires_at > 1e11:
        expires_at = expires_at / 1000
    return expires_at - time.time() > 5
# Module-level cache of the current access token (shape as returned by get_auth_token).
# get_response_json replaces it when it is None or check_auth_token reports it expired.
token_data: Optional[Dict[str, Any]] = None
def get_response_json(
    prompt: str,
    model: str,
    timeout: int = 120,
    n: int = 1,
    fuse_key_word: Optional[str] = None,
    use_giga_censor: bool = False,
    max_tokens: int = 128,
    max_attempts: int = 5,
) -> Optional[Dict[str, Any]]:
    """
    Send a chat-completion request to the API and return the decoded JSON body.

    The request carries the system prompt ``giga_system_prompt_i2v`` followed by
    the user prompt. Network errors and non-JSON responses are retried up to
    ``max_attempts`` times, pausing 5 seconds between attempts.

    Args:
        prompt (str): The input prompt, sent as the user message.
        model (str): The model to be used for generation.
        timeout (int): Per-request timeout in seconds.
        n (int): Number of responses.
        fuse_key_word (Optional[str]): If given, prepended (space-separated) to
            the user prompt.
        use_giga_censor (bool): Value sent as the request's ``profanity_check`` flag.
        max_tokens (int): Maximum number of tokens in the response.
        max_attempts (int): Maximum number of request attempts.

    Returns:
        Optional[Dict[str, Any]]: The decoded JSON body (returned regardless of
        the HTTP status code), or None if every attempt failed.
    """
    global token_data
    url = "https://beta.saluteai.sberdevices.ru/v1/chat/completions"
    user_content = ' '.join([fuse_key_word, prompt]) if fuse_key_word else prompt
    payload = json.dumps({
        "model": model,
        "messages": [
            {
                "role": "system",
                "content": giga_system_prompt_i2v
            },
            {
                "role": "user",
                "content": user_content
            }
        ],
        "temperature": 0.87,
        "top_p": 0.47,
        "n": n,
        "stream": False,
        "max_tokens": max_tokens,
        "repetition_penalty": 1.07,
        "profanity_check": use_giga_censor
    })
    # Reuse the cached token; fetch a new one only when absent or about to expire.
    if token_data is None or not check_auth_token(token_data):
        token_data = get_auth_token()
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'Authorization': f'Bearer {token_data["access_token"]}'
    }
    for attempt_num in range(1, max_attempts + 1):
        try:
            response = requests.post(url, headers=headers, data=payload, timeout=timeout)
            return response.json()
        except (requests.RequestException, ValueError) as err:
            # RequestException: connection errors / timeouts.
            # ValueError: body is not valid JSON (requests' JSONDecodeError subclasses it).
            # Narrow on purpose so Ctrl-C (KeyboardInterrupt) still stops the loop.
            if attempt_num < max_attempts:
                time.sleep(5)  # no point sleeping after the final attempt
            else:
                print(f'GigaChat request failed after {max_attempts} attempts: {err}')
    return None
def giga_generate(
    prompt: str,
    model_version: str = "GigaChat-Max",
    max_tokens: int = 128,
    max_attempts: int = 5
) -> str:
    """
    Generate text using the GigaChat model.

    Best effort: if the request ultimately fails (no response) or the response
    lacks the expected ``choices[0]`` structure, the original prompt is returned
    unchanged so callers always receive a string.

    Args:
        prompt (str): The input prompt.
        model_version (str): The version of the model to use.
        max_tokens (int): Maximum number of tokens in the response.
        max_attempts (int): Maximum number of request attempts (passed to
            get_response_json).

    Returns:
        str: The generated text; 'Censored Text' when the first choice's
        finish_reason is 'blacklist'; or the original prompt on failure.
    """
    response_dict = get_response_json(
        prompt,
        model_version,
        use_giga_censor=False,
        max_tokens=max_tokens,
        max_attempts=max_attempts,
    )
    try:
        choice = response_dict['choices'][0]
        if choice['finish_reason'] == 'blacklist':
            print('GigaCensor triggered!')
            return 'Censored Text'
        return choice['message']['content']
    except (KeyError, IndexError, TypeError):
        # TypeError covers response_dict being None (all attempts failed) or a
        # non-mapping JSON value; KeyError/IndexError cover error payloads.
        return prompt