ai-forever commited on
Commit
cc516fa
1 Parent(s): 4535fc2

add max_attempts

Browse files
Files changed (1) hide show
  1. src/gigachat.py +27 -13
src/gigachat.py CHANGED
@@ -57,7 +57,7 @@ def check_auth_token(token_data: Dict[str, Any]) -> bool:
57
 
58
  token_data: Optional[Dict[str, Any]] = None
59
 
60
- def get_response(
61
  prompt: str,
62
  model: str,
63
  timeout: int = 120,
@@ -65,6 +65,7 @@ def get_response(
65
  fuse_key_word: Optional[str] = None,
66
  use_giga_censor: bool = False,
67
  max_tokens: int = 128,
 
68
  ) -> requests.Response:
69
  """
70
  Send a text generation request to the API.
@@ -119,13 +120,24 @@ def get_response(
119
  'Accept': 'application/json',
120
  'Authorization': f'Bearer {token_data["access_token"]}'
121
  }
122
- response = requests.post(url, headers=headers, data=payload, timeout=timeout)
123
- return response
 
 
 
 
 
 
 
 
 
 
124
 
125
  def giga_generate(
126
  prompt: str,
127
  model_version: str = "GigaChat-Max",
128
- max_tokens: int = 128
 
129
  ) -> str:
130
  """
131
  Generate text using the GigaChat model.
@@ -138,17 +150,19 @@ def giga_generate(
138
  Returns:
139
  str: Generated text.
140
  """
141
- response = get_response(
142
  prompt,
143
  model_version,
144
  use_giga_censor=False,
145
  max_tokens=max_tokens,
 
146
  )
147
- response_dict = response.json()
148
-
149
- if response_dict['choices'][0]['finish_reason'] == 'blacklist':
150
- print('GigaCensor triggered!')
151
- return 'Censored Text'
152
- else:
153
- response_str = response_dict['choices'][0]['message']['content']
154
- return response_str
 
 
57
 
58
  token_data: Optional[Dict[str, Any]] = None
59
 
60
+ def get_response_json(
61
  prompt: str,
62
  model: str,
63
  timeout: int = 120,
 
65
  fuse_key_word: Optional[str] = None,
66
  use_giga_censor: bool = False,
67
  max_tokens: int = 128,
68
+ max_attempts: int = 5,
69
  ) -> requests.Response:
70
  """
71
  Send a text generation request to the API.
 
120
  'Accept': 'application/json',
121
  'Authorization': f'Bearer {token_data["access_token"]}'
122
  }
123
+
124
+ attempt_num = 0
125
+ while attempt_num < max_attempts:
126
+ try:
127
+ response = requests.post(url, headers=headers, data=payload, timeout=timeout)
128
+ response_dict = response.json()
129
+ except:
130
+ time.sleep(5)
131
+ attempt_num += 1
132
+ continue
133
+
134
+ return response_dict
135
 
136
  def giga_generate(
137
  prompt: str,
138
  model_version: str = "GigaChat-Max",
139
+ max_tokens: int = 128,
140
+ max_attempts: int = 5
141
  ) -> str:
142
  """
143
  Generate text using the GigaChat model.
 
150
  Returns:
151
  str: Generated text.
152
  """
153
+ response_dict = get_response_json(
154
  prompt,
155
  model_version,
156
  use_giga_censor=False,
157
  max_tokens=max_tokens,
158
+ max_attempts=max_attempts,
159
  )
160
+ try:
161
+ if response_dict['choices'][0]['finish_reason'] == 'blacklist':
162
+ print('GigaCensor triggered!')
163
+ return 'Censored Text'
164
+ else:
165
+ response_str = response_dict['choices'][0]['message']['content']
166
+ return response_str
167
+ except:
168
+ return prompt