Husnain committed on
Commit
e7c04bf
1 Parent(s): 8593e5c

💎 [Feature] New auth_api_key, and catch Exception and response

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +61 -47
apis/chat_api.py CHANGED
@@ -7,7 +7,7 @@ import uvicorn
7
  from pathlib import Path
8
  from typing import Union
9
 
10
- from fastapi import FastAPI, Depends
11
  from fastapi.responses import HTMLResponse
12
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
13
  from pydantic import BaseModel, Field
@@ -15,10 +15,12 @@ from sse_starlette.sse import EventSourceResponse, ServerSentEvent
15
  from tclogger import logger
16
 
17
  from constants.models import AVAILABLE_MODELS_DICTS, PRO_MODELS
18
- from constants.envs import CONFIG
 
19
 
20
  from messagers.message_composer import MessageComposer
21
  from mocks.stream_chat_mocker import stream_chat_mock
 
22
  from networks.huggingface_streamer import HuggingfaceStreamer
23
  from networks.huggingchat_streamer import HuggingchatStreamer
24
  from networks.openai_streamer import OpenaiStreamer
@@ -38,26 +40,31 @@ class ChatAPIApp:
38
  return {"object": "list", "data": AVAILABLE_MODELS_DICTS}
39
 
40
  def extract_api_key(
41
- credentials: HTTPAuthorizationCredentials = Depends(
42
- HTTPBearer(auto_error=False)
43
- ),
44
  ):
45
  api_key = None
46
  if credentials:
47
  api_key = credentials.credentials
48
- else:
49
- api_key = os.getenv("HF_TOKEN")
50
 
51
- if api_key:
52
- if api_key.startswith("hf_"):
53
- return api_key
54
- else:
55
- logger.warn(f"Invalid HF Token!")
56
- else:
57
- logger.warn("Not provide HF Token!")
58
- return None
 
 
 
 
 
 
59
 
60
  class ChatCompletionsPostItem(BaseModel):
 
61
  model: str = Field(
62
  default="nous-mixtral-8x7b",
63
  description="(str) `nous-mixtral-8x7b`",
@@ -90,38 +97,45 @@ class ChatAPIApp:
90
  def chat_completions(
91
  self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
92
  ):
93
- if item.model == "gpt-3.5-turbo":
94
- streamer = OpenaiStreamer()
95
- stream_response = streamer.chat_response(messages=item.messages)
96
- elif item.model in PRO_MODELS:
97
- streamer = HuggingchatStreamer(model=item.model)
98
- stream_response = streamer.chat_response(
99
- messages=item.messages,
100
- )
101
- else:
102
- streamer = HuggingfaceStreamer(model=item.model)
103
- composer = MessageComposer(model=item.model)
104
- composer.merge(messages=item.messages)
105
- stream_response = streamer.chat_response(
106
- prompt=composer.merged_str,
107
- temperature=item.temperature,
108
- top_p=item.top_p,
109
- max_new_tokens=item.max_tokens,
110
- api_key=api_key,
111
- use_cache=item.use_cache,
112
- )
113
-
114
- if item.stream:
115
- event_source_response = EventSourceResponse(
116
- streamer.chat_return_generator(stream_response),
117
- media_type="text/event-stream",
118
- ping=2000,
119
- ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
120
- )
121
- return event_source_response
122
- else:
123
- data_response = streamer.chat_return_dict(stream_response)
124
- return data_response
 
 
 
 
 
 
 
125
 
126
  def get_readme(self):
127
  readme_path = Path(__file__).parents[1] / "README.md"
 
7
  from pathlib import Path
8
  from typing import Union
9
 
10
+ from fastapi import FastAPI, Depends, HTTPException
11
  from fastapi.responses import HTMLResponse
12
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
13
  from pydantic import BaseModel, Field
 
15
  from tclogger import logger
16
 
17
  from constants.models import AVAILABLE_MODELS_DICTS, PRO_MODELS
18
+ from constants.envs import CONFIG, SECRETS
19
+ from networks.exceptions import HfApiException, INVALID_API_KEY_ERROR
20
 
21
  from messagers.message_composer import MessageComposer
22
  from mocks.stream_chat_mocker import stream_chat_mock
23
+
24
  from networks.huggingface_streamer import HuggingfaceStreamer
25
  from networks.huggingchat_streamer import HuggingchatStreamer
26
  from networks.openai_streamer import OpenaiStreamer
 
40
  return {"object": "list", "data": AVAILABLE_MODELS_DICTS}
41
 
42
  def extract_api_key(
43
+ credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer()),
 
 
44
  ):
45
  api_key = None
46
  if credentials:
47
  api_key = credentials.credentials
48
+ env_api_key = SECRETS["HF_LLM_API_KEY"]
49
+ return api_key
50
 
51
def auth_api_key(self, api_key: str):
    """Validate the caller's key against the configured ``HF_LLM_API_KEY``.

    Args:
        api_key: The key supplied by the caller; may be ``None``.

    Returns:
        ``api_key`` unchanged when it starts with ``hf_`` (the caller brought
        their own HF token); ``None`` when no key is configured or when the
        caller supplied the configured key.

    Raises:
        INVALID_API_KEY_ERROR: when a key is configured and the caller's key
            neither starts with ``hf_`` nor equals the configured key.
    """
    # Function-scope import keeps this block self-contained.
    import hmac

    env_api_key = SECRETS["HF_LLM_API_KEY"]

    # No key configured: the service requires no api_key.
    if not env_api_key:
        return None

    # Caller provides their own HF token: pass it through.
    if api_key and api_key.startswith("hf_"):
        return api_key

    # Caller provides the configured API key. A missing key must never match
    # (str(None) == "None"), and the comparison is constant-time so response
    # timing does not leak how much of the secret was guessed.
    if api_key is not None and hmac.compare_digest(
        str(api_key).encode("utf-8"), str(env_api_key).encode("utf-8")
    ):
        return None

    raise INVALID_API_KEY_ERROR
65
 
66
  class ChatCompletionsPostItem(BaseModel):
67
+
68
  model: str = Field(
69
  default="nous-mixtral-8x7b",
70
  description="(str) `nous-mixtral-8x7b`",
 
97
  def chat_completions(
98
  self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
99
  ):
100
+ try:
101
+ api_key = self.auth_api_key(api_key)
102
+
103
+ if item.model == "gpt-3.5-turbo":
104
+ streamer = OpenaiStreamer()
105
+ stream_response = streamer.chat_response(messages=item.messages)
106
+ elif item.model in PRO_MODELS:
107
+ streamer = HuggingchatStreamer(model=item.model)
108
+ stream_response = streamer.chat_response(
109
+ messages=item.messages,
110
+ )
111
+ else:
112
+ streamer = HuggingfaceStreamer(model=item.model)
113
+ composer = MessageComposer(model=item.model)
114
+ composer.merge(messages=item.messages)
115
+ stream_response = streamer.chat_response(
116
+ prompt=composer.merged_str,
117
+ temperature=item.temperature,
118
+ top_p=item.top_p,
119
+ max_new_tokens=item.max_tokens,
120
+ api_key=api_key,
121
+ use_cache=item.use_cache,
122
+ )
123
+
124
+ if item.stream:
125
+ event_source_response = EventSourceResponse(
126
+ streamer.chat_return_generator(stream_response),
127
+ media_type="text/event-stream",
128
+ ping=2000,
129
+ ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
130
+ )
131
+ return event_source_response
132
+ else:
133
+ data_response = streamer.chat_return_dict(stream_response)
134
+ return data_response
135
+ except HfApiException as e:
136
+ raise HTTPException(status_code=e.status_code, detail=e.detail)
137
+ except Exception as e:
138
+ raise HTTPException(status_code=500, detail=str(e))
139
 
140
  def get_readme(self):
141
  readme_path = Path(__file__).parents[1] / "README.md"