Yarik committed on
Commit
8d8934a
1 Parent(s): 199093c

update file

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +81 -136
apis/chat_api.py CHANGED
@@ -1,88 +1,97 @@
1
  import argparse
2
- import markdown2
3
  import os
4
  import sys
5
  import uvicorn
6
-
7
  from pathlib import Path
8
- from fastapi import FastAPI, Depends
9
- from fastapi.responses import HTMLResponse
10
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
  from pydantic import BaseModel, Field
12
  from typing import Union
 
13
  from sse_starlette.sse import EventSourceResponse, ServerSentEvent
14
  from utils.logger import logger
15
  from networks.message_streamer import MessageStreamer
16
  from messagers.message_composer import MessageComposer
17
  from mocks.stream_chat_mocker import stream_chat_mock
18
 
 
 
 
19
 
20
  class ChatAPIApp:
 
 
 
 
 
21
  def __init__(self):
22
- self.app = FastAPI(
23
- docs_url="/",
24
- title="HuggingFace LLM API",
25
- swagger_ui_parameters={"defaultModelsExpandDepth": -1},
26
- version="1.0",
27
- )
28
  self.setup_routes()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  def get_available_models(self):
31
- # https://platform.openai.com/docs/api-reference/models/list
32
- # ANCHOR[id=available-models]: Available models
33
  self.available_models = {
34
  "object": "list",
35
  "data": [
36
- {
37
- "id": "mixtral-8x7b",
38
- "description": "[mistralai/Mixtral-8x7B-Instruct-v0.1]: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1",
39
- "object": "model",
40
- "created": 1700000000,
41
- "owned_by": "mistralai",
42
- },
43
- {
44
- "id": "mistral-7b",
45
- "description": "[mistralai/Mistral-7B-Instruct-v0.2]: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
46
- "object": "model",
47
- "created": 1700000000,
48
- "owned_by": "mistralai",
49
- },
50
- {
51
- "id": "nous-mixtral-8x7b",
52
- "description": "[NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO]: https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
53
- "object": "model",
54
- "created": 1700000000,
55
- "owned_by": "NousResearch",
56
- },
57
- {
58
- "id": "zephyr-7b-beta",
59
- "description": "[HuggingFaceH4/zephyr-7b-beta]: https://huggingface.co/HuggingFaceH4/zephyr-7b-beta",
60
- "object": "model",
61
- "created": 1700000000,
62
- "owned_by": "TheBloke",
63
- },
64
- {
65
- "id": "starchat2-15b-v0.1",
66
- "description": "[HuggingFaceH4/starchat2-15b-v0.1]: https://huggingface.co/HuggingFaceH4/starchat2-15b-v0.1",
67
- "object": "model",
68
- "created": 1700000000,
69
- "owned_by": "TheBloke",
70
- },
71
- ],
72
  }
73
  return self.available_models
74
 
75
- def extract_api_key(
76
- credentials: HTTPAuthorizationCredentials = Depends(
77
- HTTPBearer(auto_error=False)
78
- ),
79
- ):
80
  api_key = None
81
  if credentials:
82
  api_key = credentials.credentials
83
  else:
84
  api_key = os.getenv("XCHE_TOKEN")
85
-
86
  if api_key:
87
  if api_key.startswith("hf_"):
88
  return api_key
@@ -93,43 +102,19 @@ class ChatAPIApp:
93
  return None
94
 
95
  class ChatCompletionsPostItem(BaseModel):
96
- model: str = Field(
97
- default="mixtral-8x7b",
98
- description="(str) `mixtral-8x7b`",
99
- )
100
- messages: list = Field(
101
- default=[{"role": "user", "content": "Hello, who are you?"}],
102
- description="(list) Messages",
103
- )
104
- temperature: Union[float, None] = Field(
105
- default=0.5,
106
- description="(float) Temperature",
107
- )
108
- top_p: Union[float, None] = Field(
109
- default=0.95,
110
- description="(float) top p",
111
- )
112
- max_tokens: Union[int, None] = Field(
113
- default=-1,
114
- description="(int) Max tokens",
115
- )
116
- use_cache: bool = Field(
117
- default=False,
118
- description="(bool) Use cache",
119
- )
120
- stream: bool = Field(
121
- default=True,
122
- description="(bool) Stream",
123
- )
124
-
125
- def chat_completions(
126
- self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
127
- ):
128
  streamer = MessageStreamer(model=item.model)
129
  composer = MessageComposer(model=item.model)
130
  composer.merge(messages=item.messages)
131
  # streamer.chat = stream_chat_mock
132
-
133
  stream_response = streamer.chat_response(
134
  prompt=composer.merged_str,
135
  temperature=item.temperature,
@@ -154,9 +139,7 @@ class ChatAPIApp:
154
  readme_path = Path(__file__).parents[1] / "README.md"
155
  with open(readme_path, "r", encoding="utf-8") as rf:
156
  readme_str = rf.read()
157
- readme_html = markdown2.markdown(
158
- readme_str, extras=["table", "fenced-code-blocks", "highlightjs-lang"]
159
- )
160
  return readme_html
161
 
162
  def setup_routes(self):
@@ -166,55 +149,20 @@ class ChatAPIApp:
166
  else:
167
  include_in_schema = False
168
 
169
- self.app.get(
170
- prefix + "/models",
171
- summary="Get available models",
172
- include_in_schema=include_in_schema,
173
- )(self.get_available_models)
174
-
175
- self.app.post(
176
- prefix + "/chat/completions",
177
- summary="Chat completions in conversation session",
178
- include_in_schema=include_in_schema,
179
- )(self.chat_completions)
180
- self.app.get(
181
- "/readme",
182
- summary="README of HF LLM API",
183
- response_class=HTMLResponse,
184
- include_in_schema=False,
185
- )(self.get_readme)
186
 
 
 
 
187
 
188
  class ArgParser(argparse.ArgumentParser):
189
  def __init__(self, *args, **kwargs):
190
  super(ArgParser, self).__init__(*args, **kwargs)
191
-
192
- self.add_argument(
193
- "-s",
194
- "--server",
195
- type=str,
196
- default="0.0.0.0",
197
- help="Server IP for HF LLM Chat API",
198
- )
199
- self.add_argument(
200
- "-p",
201
- "--port",
202
- type=int,
203
- default=7860,
204
- help="Server Port for HF LLM Chat API",
205
- )
206
-
207
- self.add_argument(
208
- "-d",
209
- "--dev",
210
- default=False,
211
- action="store_true",
212
- help="Run in dev mode",
213
- )
214
-
215
  self.args = self.parse_args(sys.argv[1:])
216
 
217
-
218
  app = ChatAPIApp().app
219
 
220
  if __name__ == "__main__":
@@ -223,6 +171,3 @@ if __name__ == "__main__":
223
  uvicorn.run("__main__:app", host=args.server, port=args.port, reload=True)
224
  else:
225
  uvicorn.run("__main__:app", host=args.server, port=args.port, reload=False)
226
-
227
- # python -m apis.chat_api # [Docker] on product mode
228
- # python -m apis.chat_api -d # [Dev] on develop mode
 
1
  import argparse
 
2
  import os
3
  import sys
4
  import uvicorn
5
+ import markdown2
6
  from pathlib import Path
7
+ from fastapi import FastAPI, Depends, HTTPException
8
+ from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
9
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials, OAuth2PasswordBearer, OAuth2PasswordRequestForm
10
  from pydantic import BaseModel, Field
11
  from typing import Union
12
+ from passlib.context import CryptContext
13
  from sse_starlette.sse import EventSourceResponse, ServerSentEvent
14
  from utils.logger import logger
15
  from networks.message_streamer import MessageStreamer
16
  from messagers.message_composer import MessageComposer
17
  from mocks.stream_chat_mocker import stream_chat_mock
18
 
19
class Auth(BaseModel):
    """Credential record for one entry of the API-key store.

    Mirrors the per-key dicts kept in ``ChatAPIApp.fake_data_db``:
    the key itself plus its bcrypt-hashed password.
    """

    api_key: str
    password: str
22
 
23
class ChatAPIApp:
    """FastAPI application wrapper: OAuth2-protected chat-completion API."""

    # OAuth2 bearer flow; clients obtain tokens from the POST /token endpoint.
    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
    # Password hashing context
    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
28
+
29
  def __init__(self):
30
+ self.app = FastAPI(docs_url=None, redoc_url=None)
 
 
 
 
 
31
  self.setup_routes()
32
+ self.api_key = os.getenv("XCHE_API_KEY")
33
+ self.password = os.getenv("XCHE_PASSWORD")
34
+ self.fake_data_db = {
35
+ self.api_key: {
36
+ "api_key": self.api_key,
37
+ "password": self.get_password_hash(self.password) # Pre-hashed password
38
+ }
39
+ }
40
+
41
+ def get_password_hash(self, password):
42
+ return self.pwd_context.hash(password)
43
+
44
+ def verify_password(self, plain_password, hashed_password):
45
+ return self.pwd_context.verify(plain_password, hashed_password)
46
+
47
+ def get_api_key(self, db, api_key: str):
48
+ if api_key in db:
49
+ api_dict = db[api_key]
50
+ return Auth(**api_dict)
51
+
52
+ def authenticate(self, fake_db, api_key: str, password: str):
53
+ api_data = self.get_api_key(fake_db, api_key)
54
+ if not api_data:
55
+ return False
56
+ if not self.verify_password(password, api_data.password):
57
+ return False
58
+ return api_data
59
+
60
+ async def login(self, form_data: OAuth2PasswordRequestForm = Depends()):
61
+ api_data = self.authenticate(self.fake_data_db, form_data.username, form_data.password)
62
+ if not api_data:
63
+ raise HTTPException(
64
+ status_code=400,
65
+ detail="Incorrect API KEY or Password",
66
+ headers={"WWW-Authenticate": "Bearer"},
67
+ )
68
+ return {"access_token": api_data.api_key, "token_type": "bearer"}
69
+
70
+ def check_api_token(self, token: str = Depends(oauth2_scheme)):
71
+ api_data = self.get_api_key(self.fake_data_db, token)
72
+ if not api_data:
73
+ raise HTTPException(status_code=403, detail="Invalid or missing API Key")
74
+ return api_data
75
 
76
  def get_available_models(self):
 
 
77
  self.available_models = {
78
  "object": "list",
79
  "data": [
80
+ {"id": "mixtral-8x7b", "description": "[mistralai/Mixtral-8x7B-Instruct-v0.1]: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1", "object": "model", "created": 1700000000, "owned_by": "mistralai"},
81
+ {"id": "mistral-7b", "description": "[mistralai/Mistral-7B-Instruct-v0.2]: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2", "object": "model", "created": 1700000000, "owned_by": "mistralai"},
82
+ {"id": "nous-mixtral-8x7b", "description": "[NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO]: https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", "object": "model", "created": 1700000000, "owned_by": "NousResearch"},
83
+ {"id": "zephyr-7b-beta", "description": "[HuggingFaceH4/zephyr-7b-beta]: https://huggingface.co/HuggingFaceH4/zephyr-7b-beta", "object": "model", "created": 1700000000, "owned_by": "TheBloke"},
84
+ {"id": "starchat2-15b-v0.1", "description": "[HuggingFaceH4/starchat2-15b-v0.1]: https://huggingface.co/HuggingFaceH4/starchat2-15b-v0.1", "object": "model", "created": 1700000000, "owned_by": "TheBloke"},
85
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  }
87
  return self.available_models
88
 
89
+ def extract_api_key(self, credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False))):
 
 
 
 
90
  api_key = None
91
  if credentials:
92
  api_key = credentials.credentials
93
  else:
94
  api_key = os.getenv("XCHE_TOKEN")
 
95
  if api_key:
96
  if api_key.startswith("hf_"):
97
  return api_key
 
102
  return None
103
 
104
  class ChatCompletionsPostItem(BaseModel):
105
+ model: str = Field(default="mixtral-8x7b", description="(str) `mixtral-8x7b`")
106
+ messages: list = Field(default=[{"role": "user", "content": "Hello, who are you?"}], description="(list) Messages")
107
+ temperature: Union[float, None] = Field(default=0.5, description="(float) Temperature")
108
+ top_p: Union[float, None] = Field(default=0.95, description="(float) top p")
109
+ max_tokens: Union[int, None] = Field(default=-1, description="(int) Max tokens")
110
+ use_cache: bool = Field(default=False, description="(bool) Use cache")
111
+ stream: bool = Field(default=True, description="(bool) Stream")
112
+
113
+ def chat_completions(self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  streamer = MessageStreamer(model=item.model)
115
  composer = MessageComposer(model=item.model)
116
  composer.merge(messages=item.messages)
117
  # streamer.chat = stream_chat_mock
 
118
  stream_response = streamer.chat_response(
119
  prompt=composer.merged_str,
120
  temperature=item.temperature,
 
139
  readme_path = Path(__file__).parents[1] / "README.md"
140
  with open(readme_path, "r", encoding="utf-8") as rf:
141
  readme_str = rf.read()
142
+ readme_html = markdown2.markdown(readme_str, extras=["table", "fenced-code-blocks", "highlightjs-lang"])
 
 
143
  return readme_html
144
 
145
  def setup_routes(self):
 
149
  else:
150
  include_in_schema = False
151
 
152
+ self.app.get(prefix + "/models", summary="Get available models", include_in_schema=include_in_schema)(self.get_available_models)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
+ self.app.post(prefix + "/chat/completions", summary="Chat completions in conversation session", include_in_schema=include_in_schema)(self.chat_completions)
155
+ self.app.get("/readme", summary="README of HF LLM API", response_class=HTMLResponse, include_in_schema=False)(self.get_readme)
156
+ self.app.post("/token", include_in_schema=False)(self.login)
157
 
158
class ArgParser(argparse.ArgumentParser):
    """Command-line argument parser for the HF LLM Chat API server.

    Parses ``sys.argv`` immediately on construction and exposes the parsed
    namespace as ``self.args`` (attributes: ``server``, ``port``, ``dev``).
    """

    def __init__(self, *args, **kwargs):
        # Python-3 zero-argument super() — same behavior, modern idiom.
        super().__init__(*args, **kwargs)
        self.add_argument("-s", "--server", type=str, default="0.0.0.0", help="Server IP for HF LLM Chat API")
        self.add_argument("-p", "--port", type=int, default=7860, help="Server Port for HF LLM Chat API")
        self.add_argument("-d", "--dev", default=False, action="store_true", help="Run in dev mode")
        # Parse immediately so callers can read `.args` right after construction.
        self.args = self.parse_args(sys.argv[1:])
165
 
 
166
  app = ChatAPIApp().app
167
 
168
  if __name__ == "__main__":
 
171
  uvicorn.run("__main__:app", host=args.server, port=args.port, reload=True)
172
  else:
173
  uvicorn.run("__main__:app", host=args.server, port=args.port, reload=False)