radinhas commited on
Commit
0aebbbc
·
1 Parent(s): 9131fdd

Update apis/chat_api.py

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +22 -33
apis/chat_api.py CHANGED
@@ -4,6 +4,8 @@ import sys
4
  import json
5
 
6
  from fastapi import FastAPI
 
 
7
  from pydantic import BaseModel, Field
8
  from sse_starlette.sse import EventSourceResponse
9
  from utils.logger import logger
@@ -27,52 +29,39 @@ class ChatAPIApp:
27
  return self.available_models
28
 
29
  class ChatCompletionsPostItem(BaseModel):
30
- model: str = Field(
31
- default="mixtral-8x7b",
32
- description="(str) `mixtral-8x7b`",
33
  )
34
- messages: list = Field(
35
- default=[{"role": "user", "content": "Hello, who are you?"}],
36
- description="(list) Messages",
37
  )
38
- temperature: float = Field(
39
- default=0.01,
40
- description="(float) Temperature",
41
- )
42
- max_tokens: int = Field(
43
- default=8192,
44
- description="(int) Max tokens",
45
- )
46
- stream: bool = Field(
47
- default=True,
48
- description="(bool) Stream",
49
  )
50
 
51
  def chat_completions(self, item: ChatCompletionsPostItem):
52
- streamer = MessageStreamer(model=item.model)
53
- composer = MessageComposer(model=item.model)
54
- composer.merge(messages=item.messages)
55
- return EventSourceResponse(
56
- streamer.chat(
57
- prompt=composer.merged_str,
58
- temperature=item.temperature,
59
- max_new_tokens=item.max_tokens,
60
- stream=item.stream,
61
- yield_output=True,
62
- ),
63
- media_type="text/event-stream",
64
- )
65
 
66
  def setup_routes(self):
67
  for prefix in ["", "/v1"]:
68
  self.app.get(
69
  prefix + "/models",
70
- summary="Get available models",
71
  )(self.get_available_models)
72
 
73
  self.app.post(
74
- prefix + "/chat/completions",
75
- summary="Chat completions in conversation session",
76
  )(self.chat_completions)
77
 
78
 
 
4
  import json
5
 
6
  from fastapi import FastAPI
7
+ from fastapi.encoders import jsonable_encoder
8
+ from fastapi.responses import JSONResponse
9
  from pydantic import BaseModel, Field
10
  from sse_starlette.sse import EventSourceResponse
11
  from utils.logger import logger
 
29
  return self.available_models
30
 
31
  class ChatCompletionsPostItem(BaseModel):
32
+ from_language: str = Field(
33
+ default="auto",
34
+ description="(str) `Detect`",
35
  )
36
+ to_language: str = Field(
37
+ default="en",
38
+ description="(str) `en`",
39
  )
40
+ text: str = Field(
41
+ default="Hello",
42
+ description="(str) `Text for translate`",
 
 
 
 
 
 
 
 
43
  )
44
 
45
  def chat_completions(self, item: ChatCompletionsPostItem):
46
+ item_response = {
47
+ "from_language": item.from_language,
48
+ "to_language": item.to_language,
49
+ "text": item.text,
50
+ "translate": ""
51
+ }
52
+ json_compatible_item_data = jsonable_encoder(item_response)
53
+ return JSONResponse(content=json_compatible_item_data)
 
 
 
 
 
54
 
55
  def setup_routes(self):
56
  for prefix in ["", "/v1"]:
57
  self.app.get(
58
  prefix + "/models",
59
+ summary="Get available languages",
60
  )(self.get_available_models)
61
 
62
  self.app.post(
63
+ prefix + "/translate",
64
+ summary="translate text",
65
  )(self.chat_completions)
66
 
67