H committed
Commit afe703e
Parent: f515d5f

Chat Use CVmodel (#1607)


### What problem does this PR solve?

#1230

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/db/services/dialog_service.py CHANGED

@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
+import json
 import re
 from copy import deepcopy
 
@@ -26,6 +28,7 @@ from rag.app.resume import forbidden_select_fields4resume
 from rag.nlp import keyword_extraction
 from rag.nlp.search import index_name
 from rag.utils import rmSpace, num_tokens_from_string, encoder
+from api.utils.file_utils import get_project_base_directory
 
 
 class DialogService(CommonService):
@@ -73,6 +76,15 @@ def message_fit_in(msg, max_length=4000):
     return max_length, msg
 
 
+def llm_id2llm_type(llm_id):
+    fnm = os.path.join(get_project_base_directory(), "conf")
+    llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
+    for llm_factory in llm_factories["factory_llm_infos"]:
+        for llm in llm_factory["llm"]:
+            if llm_id == llm["llm_name"]:
+                return llm["model_type"].strip(",")[-1]
+
+
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     llm = LLMService.query(llm_name=dialog.llm_id)
@@ -91,7 +103,10 @@ def chat(dialog, messages, stream=True, **kwargs):
 
     questions = [m["content"] for m in messages if m["role"] == "user"]
     embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
-    chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
+    if llm_id2llm_type(dialog.llm_id) == "image2text":
+        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
+    else:
+        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
 
     prompt_config = dialog.prompt_config
     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
@@ -328,7 +343,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
 
 
 def relevant(tenant_id, llm_id, question, contents: list):
-    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
+    if llm_id2llm_type(llm_id) == "image2text":
+        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
+    else:
+        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
     prompt = """
 You are a grader assessing relevance of a retrieved document to a user question.
 It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
@@ -347,7 +365,10 @@ def relevant(tenant_id, llm_id, question, contents: list):
 
 
 def rewrite(tenant_id, llm_id, question):
-    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
+    if llm_id2llm_type(llm_id) == "image2text":
+        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
+    else:
+        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
     prompt = """
 You are an expert at query expansion to generate a paraphrasing of a question.
 I can't retrieval relevant information from the knowledge base by using user's question directly.
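
The effect of the hunks above: `chat`, `relevant`, and `rewrite` now look up the configured model's type in `conf/llm_factories.json` and build an `LLMType.IMAGE2TEXT` bundle for vision models instead of always assuming `LLMType.CHAT`. A minimal, self-contained sketch of that lookup-and-dispatch idea (the factory entries are hypothetical, and the sketch takes the last comma-separated token with `split(",")`, which is the evident intent of the `strip(",")[-1]` lookup above):

```python
import json

# Hypothetical stand-in for conf/llm_factories.json; the shape mirrors
# the lookup in llm_id2llm_type() above.
LLM_FACTORIES = json.loads("""
{
  "factory_llm_infos": [
    {"llm": [
      {"llm_name": "qwen-vl-plus", "model_type": "image2text"},
      {"llm_name": "qwen-plus",    "model_type": "chat"}
    ]}
  ]
}
""")

def llm_id2llm_type(llm_id):
    # Return the declared model type for a model id, e.g. "image2text".
    for factory in LLM_FACTORIES["factory_llm_infos"]:
        for llm in factory["llm"]:
            if llm_id == llm["llm_name"]:
                return llm["model_type"].split(",")[-1]

# A vision model id is dispatched to an IMAGE2TEXT bundle; everything else
# keeps resolving to CHAT, exactly as in chat()/relevant()/rewrite() above.
for llm_id in ("qwen-vl-plus", "qwen-plus"):
    bundle_type = "IMAGE2TEXT" if llm_id2llm_type(llm_id) == "image2text" else "CHAT"
    print(llm_id, "->", bundle_type)
```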
api/db/services/llm_service.py CHANGED

@@ -70,7 +70,7 @@ class TenantLLMService(CommonService):
         elif llm_type == LLMType.SPEECH2TEXT.value:
             mdlnm = tenant.asr_id
         elif llm_type == LLMType.IMAGE2TEXT.value:
-            mdlnm = tenant.img2txt_id
+            mdlnm = tenant.img2txt_id if not llm_name else llm_name
         elif llm_type == LLMType.CHAT.value:
             mdlnm = tenant.llm_id if not llm_name else llm_name
         elif llm_type == LLMType.RERANK:
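
This one-line change gives IMAGE2TEXT the same fallback CHAT already had: an explicitly requested `llm_name` wins, otherwise the tenant default (`tenant.img2txt_id`) is used. A sketch of the pattern with hypothetical stand-in names:

```python
# Hypothetical helper illustrating the name-resolution fallback used for
# CHAT (tenant.llm_id) and, after this diff, IMAGE2TEXT (tenant.img2txt_id).
def resolve_model_name(tenant_default, llm_name=None):
    # An explicitly requested model overrides the tenant-level default.
    return tenant_default if not llm_name else llm_name

assert resolve_model_name("qwen-vl-plus") == "qwen-vl-plus"      # tenant default
assert resolve_model_name("qwen-vl-plus", "glm-4v") == "glm-4v"  # explicit override
```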
rag/llm/cv_model.py CHANGED

@@ -26,6 +26,7 @@ from io import BytesIO
 import json
 import requests
 
+from rag.nlp import is_english
 from api.utils import get_uuid
 from api.utils.file_utils import get_project_base_directory
 
@@ -36,7 +37,60 @@ class Base(ABC):
 
     def describe(self, image, max_tokens=300):
         raise NotImplementedError("Please implement encode method!")
+
+    def chat(self, system, history, gen_conf, image=""):
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+        try:
+            for his in history:
+                if his["role"] == "user":
+                    his["content"] = self.chat_prompt(his["content"], image)
+
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                max_tokens=gen_conf.get("max_tokens", 1000),
+                temperature=gen_conf.get("temperature", 0.3),
+                top_p=gen_conf.get("top_p", 0.7)
+            )
+            return response.choices[0].message.content.strip(), response.usage.total_tokens
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf, image=""):
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+
+        ans = ""
+        tk_count = 0
+        try:
+            for his in history:
+                if his["role"] == "user":
+                    his["content"] = self.chat_prompt(his["content"], image)
 
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                max_tokens=gen_conf.get("max_tokens", 1000),
+                temperature=gen_conf.get("temperature", 0.3),
+                top_p=gen_conf.get("top_p", 0.7),
+                stream=True
+            )
+            for resp in response:
+                if not resp.choices[0].delta.content: continue
+                delta = resp.choices[0].delta.content
+                ans += delta
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因，回答被截断了，要继续吗？"
+                    tk_count = resp.usage.total_tokens
+                if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+
     def image2base64(self, image):
         if isinstance(image, bytes):
             return base64.b64encode(image).decode("utf-8")
@@ -68,6 +122,21 @@ class Base(ABC):
             }
         ]
 
+    def chat_prompt(self, text, b64):
+        return [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": f"data:image/jpeg;base64,{b64}",
+                },
+            },
+            {
+                "type": "text",
+                "text": text
+            },
+        ]
+
+
 
 class GptV4(Base):
     def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
@@ -140,6 +209,12 @@ class QWenCV(Base):
             }
         ]
 
+    def chat_prompt(self, text, b64):
+        return [
+            {"image": f"{b64}"},
+            {"text": text},
+        ]
+
     def describe(self, image, max_tokens=300):
         from http import HTTPStatus
         from dashscope import MultiModalConversation
@@ -149,6 +224,66 @@ class QWenCV(Base):
             return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
         return response.message, 0
 
+    def chat(self, system, history, gen_conf, image=""):
+        from http import HTTPStatus
+        from dashscope import MultiModalConversation
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+
+        for his in history:
+            if his["role"] == "user":
+                his["content"] = self.chat_prompt(his["content"], image)
+        response = MultiModalConversation.call(model=self.model_name, messages=history,
+                                               max_tokens=gen_conf.get("max_tokens", 1000),
+                                               temperature=gen_conf.get("temperature", 0.3),
+                                               top_p=gen_conf.get("top_p", 0.7))
+
+        ans = ""
+        tk_count = 0
+        if response.status_code == HTTPStatus.OK:
+            ans += response.output.choices[0]['message']['content']
+            tk_count += response.usage.total_tokens
+            if response.output.choices[0].get("finish_reason", "") == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因，回答被截断了，要继续吗？"
+            return ans, tk_count
+
+        return "**ERROR**: " + response.message, tk_count
+
+    def chat_streamly(self, system, history, gen_conf, image=""):
+        from http import HTTPStatus
+        from dashscope import MultiModalConversation
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+
+        for his in history:
+            if his["role"] == "user":
+                his["content"] = self.chat_prompt(his["content"], image)
+
+        ans = ""
+        tk_count = 0
+        try:
+            response = MultiModalConversation.call(model=self.model_name, messages=history,
+                                                   max_tokens=gen_conf.get("max_tokens", 1000),
+                                                   temperature=gen_conf.get("temperature", 0.3),
+                                                   top_p=gen_conf.get("top_p", 0.7),
+                                                   stream=True)
+            for resp in response:
+                if resp.status_code == HTTPStatus.OK:
+                    ans = resp.output.choices[0]['message']['content']
+                    tk_count = resp.usage.total_tokens
+                    if resp.output.choices[0].get("finish_reason", "") == "length":
+                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                            [ans]) else "······\n由于长度的原因，回答被截断了，要继续吗？"
+                    yield ans
+                else:
+                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
+                        "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+
 
 class Zhipu4V(Base):
     def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
@@ -166,6 +301,59 @@ class Zhipu4V(Base):
         )
         return res.choices[0].message.content.strip(), res.usage.total_tokens
 
+    def chat(self, system, history, gen_conf, image=""):
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+        try:
+            for his in history:
+                if his["role"] == "user":
+                    his["content"] = self.chat_prompt(his["content"], image)
+
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                max_tokens=gen_conf.get("max_tokens", 1000),
+                temperature=gen_conf.get("temperature", 0.3),
+                top_p=gen_conf.get("top_p", 0.7)
+            )
+            return response.choices[0].message.content.strip(), response.usage.total_tokens
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf, image=""):
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+
+        ans = ""
+        tk_count = 0
+        try:
+            for his in history:
+                if his["role"] == "user":
+                    his["content"] = self.chat_prompt(his["content"], image)
+
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                max_tokens=gen_conf.get("max_tokens", 1000),
+                temperature=gen_conf.get("temperature", 0.3),
+                top_p=gen_conf.get("top_p", 0.7),
+                stream=True
+            )
+            for resp in response:
+                if not resp.choices[0].delta.content: continue
+                delta = resp.choices[0].delta.content
+                ans += delta
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因，回答被截断了，要继续吗？"
+                    tk_count = resp.usage.total_tokens
+                if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+
 
 class OllamaCV(Base):
     def __init__(self, key, model_name, lang="Chinese", **kwargs):
@@ -188,6 +376,63 @@ class OllamaCV(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
+    def chat(self, system, history, gen_conf, image=""):
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+
+        try:
+            for his in history:
+                if his["role"] == "user":
+                    his["images"] = [image]
+            options = {}
+            if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
+            if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
+            if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
+            if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
+            if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                options=options,
+                keep_alive=-1
+            )
+
+            ans = response["message"]["content"].strip()
+            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf, image=""):
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+
+        for his in history:
+            if his["role"] == "user":
+                his["images"] = [image]
+        options = {}
+        if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
+        if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
+        if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
+        if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
+        ans = ""
+        try:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                options=options,
+                keep_alive=-1
+            )
+            for resp in response:
+                if resp["done"]:
+                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
+                ans += resp["message"]["content"]
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+        yield 0
+
 
 class LocalAICV(Base):
     def __init__(self, key, model_name, base_url, lang="Chinese"):
@@ -236,7 +481,7 @@ class XinferenceCV(Base):
 
 class GeminiCV(Base):
     def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
-        from google.generativeai import client,GenerativeModel
+        from google.generativeai import client, GenerativeModel, GenerationConfig
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
         self.model_name = model_name
@@ -258,6 +503,59 @@ class GeminiCV(Base):
         )
         return res.text,res.usage_metadata.total_token_count
 
+    def chat(self, system, history, gen_conf, image=""):
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+        try:
+            for his in history:
+                if his["role"] == "assistant":
+                    his["role"] = "model"
+                    his["parts"] = [his["content"]]
+                    his.pop("content")
+                if his["role"] == "user":
+                    his["parts"] = [his["content"]]
+                    his.pop("content")
+            history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
+
+            response = self.model.generate_content(history, generation_config=GenerationConfig(
+                max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
+                top_p=gen_conf.get("top_p", 0.7)))
+
+            ans = response.text
+            return ans, response.usage_metadata.total_token_count
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf, image=""):
+        if system:
+            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
+
+        ans = ""
+        tk_count = 0
+        try:
+            for his in history:
+                if his["role"] == "assistant":
+                    his["role"] = "model"
+                    his["parts"] = [his["content"]]
+                    his.pop("content")
+                if his["role"] == "user":
+                    his["parts"] = [his["content"]]
+                    his.pop("content")
+            history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
+
+            response = self.model.generate_content(history, generation_config=GenerationConfig(
+                max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
+                top_p=gen_conf.get("top_p", 0.7)), stream=True)
+
+            for resp in response:
+                if not resp.text: continue
+                ans += resp.text
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield response._chunks[-1].usage_metadata.total_token_count
+
 
 class OpenRouterCV(Base):
     def __init__(
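
Taken together, `Base.chat` and `Base.chat_streamly` give every CV model a text-chat entry point: each user turn is rewrapped by the provider-specific `chat_prompt` so the same base64 image rides along with the whole conversation. A hypothetical usage sketch (the API key, file name, and question are placeholders; `GptV4` inherits `chat` unchanged from `Base` per the diff):

```python
import base64

from rag.llm.cv_model import GptV4

# Placeholder credentials and input file; a real key and image are assumed.
cv_mdl = GptV4(key="sk-...", model_name="gpt-4-vision-preview")

with open("invoice.jpg", "rb") as f:
    b64 = base64.b64encode(f.read()).decode("utf-8")

history = [{"role": "user", "content": "What is the total amount on this invoice?"}]
ans, tokens = cv_mdl.chat(
    system="You are a careful document assistant.",
    history=history,
    gen_conf={"temperature": 0.3, "max_tokens": 512},
    image=b64,  # chat_prompt() attaches this image to every user turn
)
print(ans, tokens)
```

`chat_streamly()` is the streaming variant of the same call: it yields the growing answer string as deltas arrive and, as its final item, the token count.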
web/src/components/llm-setting-items/index.tsx CHANGED

@@ -46,7 +46,7 @@ const LlmSettingItems = ({ prefix, formItemLayout = {} }: IProps) => {
         {...formItemLayout}
         rules={[{ required: true, message: t('modelMessage') }]}
       >
-        <Select options={modelOptions[LlmModelType.Chat]} showSearch />
+        <Select options={[...modelOptions[LlmModelType.Chat], ...modelOptions[LlmModelType.Image2text],]} showSearch/>
       </Form.Item>
       <Divider></Divider>
       <Form.Item