Kevin Hu commited on
Commit
244188b
·
1 Parent(s): 04c2182

refactor add LLM (#2508)

Browse files

### What problem does this PR solve?

#2487

### Type of change

- [x] Refactoring

Files changed (2) hide show
  1. api/apps/llm_app.py +25 -22
  2. rag/llm/chat_model.py +1 -1
api/apps/llm_app.py CHANGED
@@ -13,6 +13,8 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
 
16
  from flask import request
17
  from flask_login import login_required, current_user
18
  from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
@@ -126,55 +128,56 @@ def add_llm():
126
  req = request.json
127
  factory = req["llm_factory"]
128
 
 
 
 
 
129
  if factory == "VolcEngine":
130
  # For VolcEngine, due to its special authentication method
131
  # Assemble ark_api_key endpoint_id into api_key
132
  llm_name = req["llm_name"]
133
- api_key = f'{{ "ark_api_key":"{req.get("ark_api_key", "")}", "ep_id":"{req.get("endpoint_id", "")}" }}'
 
134
  elif factory == "Tencent Hunyuan":
135
- api_key = '{' + f'"hunyuan_sid": "{req.get("hunyuan_sid", "")}", ' \
136
- f'"hunyuan_sk": "{req.get("hunyuan_sk", "")}"' + '}'
137
- req["api_key"] = api_key
138
  return set_api_key()
 
139
  elif factory == "Tencent Cloud":
140
- api_key = '{' + f'"tencent_cloud_sid": "{req.get("tencent_cloud_sid", "")}", ' \
141
- f'"tencent_cloud_sk": "{req.get("tencent_cloud_sk", "")}"' + '}'
142
- req["api_key"] = api_key
143
  elif factory == "Bedrock":
144
  # For Bedrock, due to its special authentication method
145
  # Assemble bedrock_ak, bedrock_sk, bedrock_region
146
  llm_name = req["llm_name"]
147
- api_key = '{' + f'"bedrock_ak": "{req.get("bedrock_ak", "")}", ' \
148
- f'"bedrock_sk": "{req.get("bedrock_sk", "")}", ' \
149
- f'"bedrock_region": "{req.get("bedrock_region", "")}", ' + '}'
150
  elif factory == "LocalAI":
151
  llm_name = req["llm_name"]+"___LocalAI"
152
  api_key = "xxxxxxxxxxxxxxx"
 
153
  elif factory == "OpenAI-API-Compatible":
154
  llm_name = req["llm_name"]+"___OpenAI-API"
155
  api_key = req.get("api_key","xxxxxxxxxxxxxxx")
 
156
  elif factory =="XunFei Spark":
157
  llm_name = req["llm_name"]
158
- api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx")
 
159
  elif factory == "BaiduYiyan":
160
  llm_name = req["llm_name"]
161
- api_key = '{' + f'"yiyan_ak": "{req.get("yiyan_ak", "")}", ' \
162
- f'"yiyan_sk": "{req.get("yiyan_sk", "")}"' + '}'
163
  elif factory == "Fish Audio":
164
  llm_name = req["llm_name"]
165
- api_key = '{' + f'"fish_audio_ak": "{req.get("fish_audio_ak", "")}", ' \
166
- f'"fish_audio_refid": "{req.get("fish_audio_refid", "59cb5986671546eaa6ca8ae6f29f6d22")}"' + '}'
167
  elif factory == "Google Cloud":
168
  llm_name = req["llm_name"]
169
- api_key = (
170
- "{" + f'"google_project_id": "{req.get("google_project_id", "")}", '
171
- f'"google_region": "{req.get("google_region", "")}", '
172
- f'"google_service_account_key": "{req.get("google_service_account_key", "")}"'
173
- + "}"
174
- )
175
  else:
176
  llm_name = req["llm_name"]
177
- api_key = req.get("api_key","xxxxxxxxxxxxxxx")
178
 
179
  llm = {
180
  "tenant_id": current_user.id,
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ import json
17
+
18
  from flask import request
19
  from flask_login import login_required, current_user
20
  from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
 
128
  req = request.json
129
  factory = req["llm_factory"]
130
 
131
+ def apikey_json(keys):
132
+ nonlocal req
133
+ return json.dumps({k: req.get(k, "") for k in keys})
134
+
135
  if factory == "VolcEngine":
136
  # For VolcEngine, due to its special authentication method
137
  # Assemble ark_api_key endpoint_id into api_key
138
  llm_name = req["llm_name"]
139
+ api_key = apikey_json(["ark_api_key", "endpoint_id"])
140
+
141
  elif factory == "Tencent Hunyuan":
142
+ req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"])
 
 
143
  return set_api_key()
144
+
145
  elif factory == "Tencent Cloud":
146
+ req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"])
147
+
 
148
  elif factory == "Bedrock":
149
  # For Bedrock, due to its special authentication method
150
  # Assemble bedrock_ak, bedrock_sk, bedrock_region
151
  llm_name = req["llm_name"]
152
+ api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])
153
+
 
154
  elif factory == "LocalAI":
155
  llm_name = req["llm_name"]+"___LocalAI"
156
  api_key = "xxxxxxxxxxxxxxx"
157
+
158
  elif factory == "OpenAI-API-Compatible":
159
  llm_name = req["llm_name"]+"___OpenAI-API"
160
  api_key = req.get("api_key","xxxxxxxxxxxxxxx")
161
+
162
  elif factory =="XunFei Spark":
163
  llm_name = req["llm_name"]
164
+ api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx")
165
+
166
  elif factory == "BaiduYiyan":
167
  llm_name = req["llm_name"]
168
+ api_key = apikey_json(["yiyan_ak", "yiyan_sk"])
169
+
170
  elif factory == "Fish Audio":
171
  llm_name = req["llm_name"]
172
+ api_key = apikey_json(["fish_audio_ak", "fish_audio_refid"])
173
+
174
  elif factory == "Google Cloud":
175
  llm_name = req["llm_name"]
176
+ api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"])
177
+
 
 
 
 
178
  else:
179
  llm_name = req["llm_name"]
180
+ api_key = req.get("api_key", "xxxxxxxxxxxxxxx")
181
 
182
  llm = {
183
  "tenant_id": current_user.id,
rag/llm/chat_model.py CHANGED
@@ -458,7 +458,7 @@ class VolcEngineChat(Base):
458
  """
459
  base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3'
460
  ark_api_key = json.loads(key).get('ark_api_key', '')
461
- model_name = json.loads(key).get('ep_id', '')
462
  super().__init__(ark_api_key, model_name, base_url)
463
 
464
 
 
458
  """
459
  base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3'
460
  ark_api_key = json.loads(key).get('ark_api_key', '')
461
+ model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
462
  super().__init__(ark_api_key, model_name, base_url)
463
 
464