KevinHuSh
commited on
Commit
·
db713d9
1
Parent(s):
30d6885
Add 2 embeding models from OpenAI (#812)
Browse files### What problem does this PR solve?
#810
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/db/init_data.py +30 -0
- api/db/services/llm_service.py +10 -0
api/db/init_data.py
CHANGED
@@ -16,6 +16,7 @@
|
|
16 |
import os
|
17 |
import time
|
18 |
import uuid
|
|
|
19 |
|
20 |
from api.db import LLMType, UserTenantRole
|
21 |
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
|
@@ -166,6 +167,18 @@ def init_llm_factory():
|
|
166 |
"tags": "TEXT EMBEDDING,8K",
|
167 |
"max_tokens": 8191,
|
168 |
"model_type": LLMType.EMBEDDING.value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
}, {
|
170 |
"fid": factory_infos[0]["name"],
|
171 |
"llm_name": "whisper-1",
|
@@ -376,6 +389,23 @@ def init_llm_factory():
|
|
376 |
LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"])
|
377 |
LLMService.filter_delete([LLMService.model.fid == "QAnything"])
|
378 |
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
379 |
"""
|
380 |
drop table llm;
|
381 |
drop table llm_factories;
|
|
|
16 |
import os
|
17 |
import time
|
18 |
import uuid
|
19 |
+
from copy import deepcopy
|
20 |
|
21 |
from api.db import LLMType, UserTenantRole
|
22 |
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
|
|
|
167 |
"tags": "TEXT EMBEDDING,8K",
|
168 |
"max_tokens": 8191,
|
169 |
"model_type": LLMType.EMBEDDING.value
|
170 |
+
}, {
|
171 |
+
"fid": factory_infos[0]["name"],
|
172 |
+
"llm_name": "text-embedding-3-small",
|
173 |
+
"tags": "TEXT EMBEDDING,8K",
|
174 |
+
"max_tokens": 8191,
|
175 |
+
"model_type": LLMType.EMBEDDING.value
|
176 |
+
}, {
|
177 |
+
"fid": factory_infos[0]["name"],
|
178 |
+
"llm_name": "text-embedding-3-large",
|
179 |
+
"tags": "TEXT EMBEDDING,8K",
|
180 |
+
"max_tokens": 8191,
|
181 |
+
"model_type": LLMType.EMBEDDING.value
|
182 |
}, {
|
183 |
"fid": factory_infos[0]["name"],
|
184 |
"llm_name": "whisper-1",
|
|
|
389 |
LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"])
|
390 |
LLMService.filter_delete([LLMService.model.fid == "QAnything"])
|
391 |
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
|
392 |
+
## insert openai two embedding models to the current openai user.
|
393 |
+
print("Start to insert 2 OpenAI embedding models...")
|
394 |
+
tenant_ids = set([row.tenant_id for row in TenantLLMService.get_openai_models()])
|
395 |
+
for tid in tenant_ids:
|
396 |
+
for row in TenantLLMService.get_openai_models(llm_factory="OpenAI", tenant_id=tid):
|
397 |
+
row = row.to_dict()
|
398 |
+
row["model_type"] = LLMType.EMBEDDING.value
|
399 |
+
row["llm_name"] = "text-embedding-3-small"
|
400 |
+
row["used_tokens"] = 0
|
401 |
+
try:
|
402 |
+
TenantLLMService.save(**row)
|
403 |
+
row = deepcopy(row)
|
404 |
+
row["llm_name"] = "text-embedding-3-large"
|
405 |
+
TenantLLMService.save(**row)
|
406 |
+
except Exception as e:
|
407 |
+
pass
|
408 |
+
break
|
409 |
"""
|
410 |
drop table llm;
|
411 |
drop table llm_factories;
|
api/db/services/llm_service.py
CHANGED
@@ -135,6 +135,16 @@ class TenantLLMService(CommonService):
|
|
135 |
.execute()
|
136 |
return num
|
137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
|
139 |
class LLMBundle(object):
|
140 |
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
|
|
|
135 |
.execute()
|
136 |
return num
|
137 |
|
138 |
+
@classmethod
|
139 |
+
@DB.connection_context()
|
140 |
+
def get_openai_models(cls):
|
141 |
+
objs = cls.model.select().where(
|
142 |
+
(cls.model.llm_factory == "OpenAI"),
|
143 |
+
~(cls.model.llm_name == "text-embedding-3-small"),
|
144 |
+
~(cls.model.llm_name == "text-embedding-3-large")
|
145 |
+
).dicts()
|
146 |
+
return list(objs)
|
147 |
+
|
148 |
|
149 |
class LLMBundle(object):
|
150 |
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
|