KevinHuSh committed on
Commit
db713d9
·
1 Parent(s): 30d6885

Add 2 embedding models from OpenAI (#812)

Browse files

### What problem does this PR solve?

#810

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/db/init_data.py CHANGED
@@ -16,6 +16,7 @@
16
  import os
17
  import time
18
  import uuid
 
19
 
20
  from api.db import LLMType, UserTenantRole
21
  from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
@@ -166,6 +167,18 @@ def init_llm_factory():
166
  "tags": "TEXT EMBEDDING,8K",
167
  "max_tokens": 8191,
168
  "model_type": LLMType.EMBEDDING.value
 
 
 
 
 
 
 
 
 
 
 
 
169
  }, {
170
  "fid": factory_infos[0]["name"],
171
  "llm_name": "whisper-1",
@@ -376,6 +389,23 @@ def init_llm_factory():
376
  LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"])
377
  LLMService.filter_delete([LLMService.model.fid == "QAnything"])
378
  TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
  """
380
  drop table llm;
381
  drop table llm_factories;
 
16
  import os
17
  import time
18
  import uuid
19
+ from copy import deepcopy
20
 
21
  from api.db import LLMType, UserTenantRole
22
  from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
 
167
  "tags": "TEXT EMBEDDING,8K",
168
  "max_tokens": 8191,
169
  "model_type": LLMType.EMBEDDING.value
170
+ }, {
171
+ "fid": factory_infos[0]["name"],
172
+ "llm_name": "text-embedding-3-small",
173
+ "tags": "TEXT EMBEDDING,8K",
174
+ "max_tokens": 8191,
175
+ "model_type": LLMType.EMBEDDING.value
176
+ }, {
177
+ "fid": factory_infos[0]["name"],
178
+ "llm_name": "text-embedding-3-large",
179
+ "tags": "TEXT EMBEDDING,8K",
180
+ "max_tokens": 8191,
181
+ "model_type": LLMType.EMBEDDING.value
182
  }, {
183
  "fid": factory_infos[0]["name"],
184
  "llm_name": "whisper-1",
 
389
  LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"])
390
  LLMService.filter_delete([LLMService.model.fid == "QAnything"])
391
  TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
392
+ ## Insert the two new OpenAI embedding models for every tenant that already uses OpenAI.
393
+ print("Start to insert 2 OpenAI embedding models...")
394
+ tenant_ids = set([row.tenant_id for row in TenantLLMService.get_openai_models()])
395
+ for tid in tenant_ids:
396
+ for row in TenantLLMService.get_openai_models(llm_factory="OpenAI", tenant_id=tid):
397
+ row = row.to_dict()
398
+ row["model_type"] = LLMType.EMBEDDING.value
399
+ row["llm_name"] = "text-embedding-3-small"
400
+ row["used_tokens"] = 0
401
+ try:
402
+ TenantLLMService.save(**row)
403
+ row = deepcopy(row)
404
+ row["llm_name"] = "text-embedding-3-large"
405
+ TenantLLMService.save(**row)
406
+ except Exception as e:
407
+ pass
408
+ break
409
  """
410
  drop table llm;
411
  drop table llm_factories;
api/db/services/llm_service.py CHANGED
@@ -135,6 +135,16 @@ class TenantLLMService(CommonService):
135
  .execute()
136
  return num
137
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  class LLMBundle(object):
140
  def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
 
135
  .execute()
136
  return num
137
 
138
@classmethod
@DB.connection_context()
def get_openai_models(cls, llm_factory="OpenAI", tenant_id=None):
    """Return tenant LLM rows for a factory, excluding the v3 embedding models.

    Rows whose ``llm_name`` is ``text-embedding-3-small`` or
    ``text-embedding-3-large`` are always filtered out.

    Args:
        llm_factory: Factory name to match against ``llm_factory``.
            Defaults to ``"OpenAI"``, which is what a call without
            arguments has always selected.
        tenant_id: When given, only rows belonging to this tenant are
            returned; ``None`` (the default) returns rows of all tenants.

    Returns:
        A list of plain ``dict`` rows (the query is executed with
        ``.dicts()``), not model instances. Read fields by key, e.g.
        ``row["tenant_id"]``; the rows have no ``to_dict()`` method.
    """
    conditions = [
        cls.model.llm_factory == llm_factory,
        cls.model.llm_name != "text-embedding-3-small",
        cls.model.llm_name != "text-embedding-3-large",
    ]
    # The tenant filter is optional so that the argument-less call keeps
    # returning rows for every tenant.
    if tenant_id is not None:
        conditions.append(cls.model.tenant_id == tenant_id)

    # Multiple expressions handed to where() are AND-ed together.
    query = cls.model.select().where(*conditions).dicts()
    return list(query)
147
+
148
 
149
  class LLMBundle(object):
150
  def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):