Hansimov commited on
Commit
06e3150
1 Parent(s): 62d5db7

:gem: [Feature] ChatAPIApp: Enable chat with pro models

Browse files
Files changed (2) hide show
  1. apis/chat_api.py +7 -1
  2. constants/models.py +2 -0
apis/chat_api.py CHANGED
@@ -14,12 +14,13 @@ from pydantic import BaseModel, Field
14
  from sse_starlette.sse import EventSourceResponse, ServerSentEvent
15
  from tclogger import logger
16
 
17
- from constants.models import AVAILABLE_MODELS_DICTS
18
  from constants.envs import CONFIG
19
 
20
  from messagers.message_composer import MessageComposer
21
  from mocks.stream_chat_mocker import stream_chat_mock
22
  from networks.huggingface_streamer import HuggingfaceStreamer
 
23
  from networks.openai_streamer import OpenaiStreamer
24
 
25
 
@@ -92,6 +93,11 @@ class ChatAPIApp:
92
  if item.model == "gpt-3.5-turbo":
93
  streamer = OpenaiStreamer()
94
  stream_response = streamer.chat_response(messages=item.messages)
 
 
 
 
 
95
  else:
96
  streamer = HuggingfaceStreamer(model=item.model)
97
  composer = MessageComposer(model=item.model)
 
14
  from sse_starlette.sse import EventSourceResponse, ServerSentEvent
15
  from tclogger import logger
16
 
17
+ from constants.models import AVAILABLE_MODELS_DICTS, PRO_MODELS
18
  from constants.envs import CONFIG
19
 
20
  from messagers.message_composer import MessageComposer
21
  from mocks.stream_chat_mocker import stream_chat_mock
22
  from networks.huggingface_streamer import HuggingfaceStreamer
23
+ from networks.huggingchat_streamer import HuggingchatStreamer
24
  from networks.openai_streamer import OpenaiStreamer
25
 
26
 
 
93
  if item.model == "gpt-3.5-turbo":
94
  streamer = OpenaiStreamer()
95
  stream_response = streamer.chat_response(messages=item.messages)
96
+ elif item.model in PRO_MODELS:
97
+ streamer = HuggingchatStreamer(model=item.model)
98
+ stream_response = streamer.chat_response(
99
+ messages=item.messages,
100
+ )
101
  else:
102
  streamer = HuggingfaceStreamer(model=item.model)
103
  composer = MessageComposer(model=item.model)
constants/models.py CHANGED
@@ -12,6 +12,8 @@ MODEL_MAP = {
12
 
13
  AVAILABLE_MODELS = list(MODEL_MAP.keys())
14
 
 
 
15
  STOP_SEQUENCES_MAP = {
16
  "mixtral-8x7b": "</s>",
17
  "nous-mixtral-8x7b": "<|im_end|>",
 
12
 
13
  AVAILABLE_MODELS = list(MODEL_MAP.keys())
14
 
15
+ PRO_MODELS = ["command-r-plus", "llama3-70b", "zephyr-141b"]
16
+
17
  STOP_SEQUENCES_MAP = {
18
  "mixtral-8x7b": "</s>",
19
  "nous-mixtral-8x7b": "<|im_end|>",