yym68686 committed on
Commit
3b159d8
·
1 Parent(s): edb14b7

Add Gemini region load balancing.

Browse files
Files changed (4) hide show
  1. README.md +5 -1
  2. request.py +157 -6
  3. requirements.txt +3 -2
  4. utils.py +36 -1
README.md CHANGED
@@ -55,10 +55,14 @@ providers:
55
  - provider: vertex
56
  project_id: gen-lang-client-xxxxxxxxxxxxxx # 描述: 您的Google Cloud项目ID。格式: 字符串,通常由小写字母、数字和连字符组成。获取方式: 在Google Cloud Console的项目选择器中可以找到您的项目ID。
57
  private_key: "-----BEGIN PRIVATE KEY-----\nxxxxx\n-----END PRIVATE" # 描述: Google Cloud Vertex AI服务账号的私钥。格式: 一个JSON格式的字符串,包含服务账号的私钥信息。获取方式: 在Google Cloud Console中创建服务账号,生成JSON格式的密钥文件,然后将其内容设置为此环境变量的值。
58
- client_email: xxxxxxxxxx@developer.gserviceaccount.com # 描述: Google Cloud Vertex AI服务账号的电子邮件地址。格式: 通常是形如 "service-account-name@project-id.iam.gserviceaccount.com" 的字符串。获取方式: 在创建服务账号时生成,也可以在Google Cloud Console的"IAM与管理"部分查看服务账号详情获得。
59
  model:
60
  - gemini-1.5-pro
61
  - gemini-1.5-flash
 
 
 
 
62
  tools: true
63
 
64
  - provider: other-provider
 
55
  - provider: vertex
56
  project_id: gen-lang-client-xxxxxxxxxxxxxx # 描述: 您的Google Cloud项目ID。格式: 字符串,通常由小写字母、数字和连字符组成。获取方式: 在Google Cloud Console的项目选择器中可以找到您的项目ID。
57
  private_key: "-----BEGIN PRIVATE KEY-----\nxxxxx\n-----END PRIVATE" # 描述: Google Cloud Vertex AI服务账号的私钥。格式: 一个JSON格式的字符串,包含服务账号的私钥信息。获取方式: 在Google Cloud Console中创建服务账号,生成JSON格式的密钥文件,然后将其内容设置为此环境变量的值。
58
+ client_email: xxxxxxxxxx@xxxxxxx.gserviceaccount.com # 描述: Google Cloud Vertex AI服务账号的电子邮件地址。格式: 通常是形如 "service-account-name@project-id.iam.gserviceaccount.com" 的字符串。获取方式: 在创建服务账号时生成,也可以在Google Cloud Console的"IAM与管理"部分查看服务账号详情获得。
59
  model:
60
  - gemini-1.5-pro
61
  - gemini-1.5-flash
62
+ - claude-3-5-sonnet@20240620: claude-3-5-sonnet
63
+ - claude-3-opus@20240229: claude-3-opus
64
+ - claude-3-sonnet@20240229: claude-3-sonnet
65
+ - claude-3-haiku@20240307: claude-3-haiku
66
  tools: true
67
 
68
  - provider: other-provider
request.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
2
  from models import RequestModel
3
- from log_config import logger
4
 
5
  async def get_image_message(base64_image, engine = None):
6
  if "gpt" == engine:
@@ -222,19 +222,168 @@ def get_access_token(client_email, private_key):
222
  response.raise_for_status()
223
  return response.json()["access_token"]
224
 
225
- async def get_vertex_payload(request, engine, provider):
226
  headers = {
227
  'Content-Type': 'application/json'
228
  }
229
  if provider.get("client_email") and provider.get("private_key"):
230
  access_token = get_access_token(provider['client_email'], provider['private_key'])
231
  headers['Authorization'] = f"Bearer {access_token}"
232
- model = provider['model'][request.model]
 
 
233
  if request.stream:
234
  gemini_stream = "streamGenerateContent"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  if provider.get("project_id"):
236
  project_id = provider.get("project_id")
237
- url = "https://us-central1-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/us-central1/publishers/google/models/{MODEL_ID}:{stream}".format(PROJECT_ID=project_id, MODEL_ID=model, stream=gemini_stream)
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  messages = []
240
  systemInstruction = None
@@ -620,8 +769,10 @@ async def get_claude_payload(request, engine, provider):
620
  async def get_payload(request: RequestModel, engine, provider):
621
  if engine == "gemini":
622
  return await get_gemini_payload(request, engine, provider)
623
- elif engine == "vertex":
624
- return await get_vertex_payload(request, engine, provider)
 
 
625
  elif engine == "claude":
626
  return await get_claude_payload(request, engine, provider)
627
  elif engine == "gpt":
 
1
  import json
2
  from models import RequestModel
3
+ from utils import c35s, c3s, c3o, c3h, CircularList
4
 
5
  async def get_image_message(base64_image, engine = None):
6
  if "gpt" == engine:
 
222
  response.raise_for_status()
223
  return response.json()["access_token"]
224
 
225
# Vertex AI regions the Gemini requests are spread across, in rotation order.
_VERTEX_GEMINI_REGIONS = (
    "us-central1",
    "us-east4",
    "us-west1",
    "us-west4",
    "europe-west1",
    "europe-west2",
)

# Lazily created so the rotation state survives across calls; building the
# CircularList inside the request function would restart it at the first
# region on every request and defeat the load balancing.
_vertex_gemini_rotation = None


def _next_vertex_gemini_region():
    """Return the next Vertex AI region in round-robin order."""
    global _vertex_gemini_rotation
    if _vertex_gemini_rotation is None:
        _vertex_gemini_rotation = CircularList(_VERTEX_GEMINI_REGIONS)
    return _vertex_gemini_rotation.next()


async def get_vertex_gemini_payload(request, engine, provider):
    """Build the URL, headers and JSON payload for a Gemini call on Vertex AI.

    Args:
        request: The incoming chat request. Reads ``model``, ``stream``,
            ``messages`` and the fields returned by ``model_dump``.
        engine: Engine name forwarded to the text/image message helpers.
        provider: Provider configuration mapping. Reads ``project_id``,
            ``model`` and, optionally, ``client_email`` / ``private_key``.

    Returns:
        A ``(url, headers, payload)`` tuple.

    Raises:
        ValueError: If the provider has no ``project_id``, or a ``tool``
            message appears before any assistant tool call.
        KeyError: If ``request.model`` is not in ``provider['model']``.
    """
    headers = {
        'Content-Type': 'application/json'
    }
    if provider.get("client_email") and provider.get("private_key"):
        access_token = get_access_token(provider['client_email'], provider['private_key'])
        headers['Authorization'] = f"Bearer {access_token}"

    project_id = provider.get("project_id")
    if not project_id:
        # Previously surfaced as an UnboundLocalError while formatting the URL.
        raise ValueError("Vertex provider configuration is missing 'project_id'")

    # Non-streaming requests previously left this name unbound.
    if request.stream:
        gemini_stream = "streamGenerateContent"
    else:
        gemini_stream = "generateContent"

    model = provider['model'][request.model]
    url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(
        LOCATION=_next_vertex_gemini_region(),
        PROJECT_ID=project_id,
        MODEL_ID=model,
        stream=gemini_stream,
    )

    messages = []
    systemInstruction = None
    # Most recent functionCall part; a following "tool" message is attributed to it.
    function_arguments = None
    for msg in request.messages:
        # Map the role locally rather than overwriting msg.role, so the
        # caller's request object is left untouched.
        role = "model" if msg.role == "assistant" else msg.role
        tool_calls = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(role, item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url":
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = [{"text": msg.content}]
            tool_calls = msg.tool_calls

        if tool_calls:
            # Only the first tool call of the message is forwarded.
            tool_call = tool_calls[0]
            function_arguments = {
                "functionCall": {
                    "name": tool_call.function.name,
                    "args": json.loads(tool_call.function.arguments)
                }
            }
            messages.append(
                {
                    "role": "model",
                    "parts": [function_arguments]
                }
            )
        elif role == "tool":
            if function_arguments is None:
                raise ValueError("Received a 'tool' message without a preceding tool call")
            function_call_name = function_arguments["functionCall"]["name"]
            messages.append(
                {
                    "role": "function",
                    "parts": [{
                        "functionResponse": {
                            "name": function_call_name,
                            "response": {
                                "name": function_call_name,
                                "content": {
                                    "result": msg.content,
                                }
                            }
                        }
                    }]
                }
            )
        elif role != "system":
            messages.append({"role": role, "parts": content})
        else:
            # System messages are sent separately as the system instruction.
            systemInstruction = {"parts": content}

    payload = {
        "contents": messages,
        # "safetySettings": [
        #     {
        #         "category": "HARM_CATEGORY_HARASSMENT",
        #         "threshold": "BLOCK_NONE"
        #     },
        #     {
        #         "category": "HARM_CATEGORY_HATE_SPEECH",
        #         "threshold": "BLOCK_NONE"
        #     },
        #     {
        #         "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        #         "threshold": "BLOCK_NONE"
        #     },
        #     {
        #         "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        #         "threshold": "BLOCK_NONE"
        #     }
        # ]
        "generationConfig": {
            "temperature": 0.5,
            "max_output_tokens": 8192,
            "top_k": 40,
            "top_p": 0.95
        },
    }
    if systemInstruction:
        payload["system_instruction"] = systemInstruction

    # Request fields that are not copied verbatim into the payload.
    miss_fields = {
        'model',
        'messages',
        'stream',
        'tool_choice',
        'temperature',
        'top_p',
        'max_tokens',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'logprobs',
        'top_logprobs',
    }

    for field, value in request.model_dump(exclude_unset=True).items():
        if field in miss_fields or value is None:
            continue
        if field == "tools":
            # Re-shape the tool list into function declarations.
            payload.update({
                "tools": [{
                    "function_declarations": [tool["function"] for tool in value]
                }],
                "tool_config": {
                    "function_calling_config": {
                        "mode": "AUTO"
                    }
                }
            })
        else:
            payload[field] = value

    return url, headers, payload
363
+
364
+ async def get_vertex_claude_payload(request, engine, provider):
365
+ headers = {
366
+ 'Content-Type': 'application/json'
367
+ }
368
+ if provider.get("client_email") and provider.get("private_key"):
369
+ access_token = get_access_token(provider['client_email'], provider['private_key'])
370
+ headers['Authorization'] = f"Bearer {access_token}"
371
  if provider.get("project_id"):
372
  project_id = provider.get("project_id")
373
+
374
+ model = provider['model'][request.model]
375
+ if "claude-3-5-sonnet" in model:
376
+ location = c35s
377
+ elif "claude-3-opus" in model:
378
+ location = c3o
379
+ elif "claude-3-sonnet" in model:
380
+ location = c3s
381
+ elif "claude-3-haiku" in model:
382
+ location = c3h
383
+
384
+ if request.stream:
385
+ claude_stream = "streamRawPredict"
386
+ url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL=model, stream=claude_stream)
387
 
388
  messages = []
389
  systemInstruction = None
 
769
  async def get_payload(request: RequestModel, engine, provider):
770
  if engine == "gemini":
771
  return await get_gemini_payload(request, engine, provider)
772
+ elif engine == "vertex" and "gemini" in provider['model'][request.model]:
773
+ return await get_vertex_gemini_payload(request, engine, provider)
774
+ elif engine == "vertex" and "claude" in provider['model'][request.model]:
775
+ return await get_vertex_claude_payload(request, engine, provider)
776
  elif engine == "claude":
777
  return await get_claude_payload(request, engine, provider)
778
  elif engine == "gpt":
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
- httpx[http2]
2
  pyyaml
3
  pytest
4
  uvicorn
5
- fastapi
 
 
 
 
1
  pyyaml
2
  pytest
3
  uvicorn
4
+ fastapi
5
+ httpx[http2]
6
+ cryptography
utils.py CHANGED
@@ -185,4 +185,39 @@ def get_all_models(config):
185
  }
186
  all_models.append(model_info)
187
 
188
- return all_models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  }
186
  all_models.append(model_info)
187
 
188
+ return all_models
189
+
190
+ # 【GCP-Vertex AI 目前有這些區域可用】 https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude?hl=zh_cn
191
+ # c3.5s
192
+ # us-east5
193
+ # europe-west1
194
+
195
+ # c3s
196
+ # us-east5
197
+ # us-central1
198
+ # asia-southeast1
199
+
200
+ # c3o
201
+ # us-east5
202
+
203
+ # c3h
204
+ # us-east5
205
+ # us-central1
206
+ # europe-west1
207
+ # europe-west4
208
+ from collections import deque
209
class CircularList:
    """Endless round-robin cursor over a collection of items."""

    def __init__(self, items):
        # Copy into a deque so the sequence passed in is never mutated.
        self.queue = deque(items)

    def next(self):
        """Return the item at the front and move it to the back.

        Returns None when the list holds no items.
        """
        ring = self.queue
        if len(ring) == 0:
            return None
        head = ring[0]
        # Rotating left by one has the same effect as append(popleft()).
        ring.rotate(-1)
        return head
219
+
220
# Module-level region rotations, one per Claude model family on Vertex AI.
# They are shared instances, so each call to .next() advances the rotation
# for every importer of the name.
c35s = CircularList([  # Claude 3.5 Sonnet
    "us-east5",
    "europe-west1",
])
c3s = CircularList([  # Claude 3 Sonnet
    "us-east5",
    "us-central1",
    "asia-southeast1",
])
c3o = CircularList([  # Claude 3 Opus
    "us-east5",
])
c3h = CircularList([  # Claude 3 Haiku
    "us-east5",
    "us-central1",
    "europe-west1",
    "europe-west4",
])