Hansimov commited on
Commit
2da6968
1 Parent(s): d2b20f2

:gem: [Feature] Support call hf api with api_key via HTTP Bearer

Browse files
Files changed (2) hide show
  1. apis/chat_api.py +16 -2
  2. networks/message_streamer.py +8 -0
apis/chat_api.py CHANGED
@@ -2,7 +2,8 @@ import argparse
2
  import uvicorn
3
  import sys
4
 
5
- from fastapi import FastAPI
 
6
  from pydantic import BaseModel, Field
7
  from sse_starlette.sse import EventSourceResponse, ServerSentEvent
8
  from utils.logger import logger
@@ -38,6 +39,16 @@ class ChatAPIApp:
38
  ]
39
  return self.available_models
40
 
 
 
 
 
 
 
 
 
 
 
41
  class ChatCompletionsPostItem(BaseModel):
42
  model: str = Field(
43
  default="mixtral-8x7b",
@@ -60,7 +71,9 @@ class ChatAPIApp:
60
  description="(bool) Stream",
61
  )
62
 
63
- def chat_completions(self, item: ChatCompletionsPostItem):
 
 
64
  streamer = MessageStreamer(model=item.model)
65
  composer = MessageComposer(model=item.model)
66
  composer.merge(messages=item.messages)
@@ -70,6 +83,7 @@ class ChatAPIApp:
70
  prompt=composer.merged_str,
71
  temperature=item.temperature,
72
  max_new_tokens=item.max_tokens,
 
73
  )
74
  if item.stream:
75
  event_source_response = EventSourceResponse(
 
2
  import uvicorn
3
  import sys
4
 
5
+ from fastapi import FastAPI, Depends
6
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
7
  from pydantic import BaseModel, Field
8
  from sse_starlette.sse import EventSourceResponse, ServerSentEvent
9
  from utils.logger import logger
 
39
  ]
40
  return self.available_models
41
 
42
+ def extract_api_key(
43
+ credentials: HTTPAuthorizationCredentials = Depends(
44
+ HTTPBearer(auto_error=False)
45
+ ),
46
+ ):
47
+ if credentials:
48
+ return credentials.credentials
49
+ else:
50
+ return None
51
+
52
  class ChatCompletionsPostItem(BaseModel):
53
  model: str = Field(
54
  default="mixtral-8x7b",
 
71
  description="(bool) Stream",
72
  )
73
 
74
+ def chat_completions(
75
+ self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
76
+ ):
77
  streamer = MessageStreamer(model=item.model)
78
  composer = MessageComposer(model=item.model)
79
  composer.merge(messages=item.messages)
 
83
  prompt=composer.merged_str,
84
  temperature=item.temperature,
85
  max_new_tokens=item.max_tokens,
86
+ api_key=api_key,
87
  )
88
  if item.stream:
89
  event_source_response = EventSourceResponse(
networks/message_streamer.py CHANGED
@@ -36,6 +36,7 @@ class MessageStreamer:
36
  prompt: str = None,
37
  temperature: float = 0.01,
38
  max_new_tokens: int = 8192,
 
39
  ):
40
  # https://huggingface.co/docs/api-inference/detailed_parameters?code=curl
41
  # curl --proxy http://<server>:<port> https://api-inference.huggingface.co/models/<org>/<model_name> -X POST -d '{"inputs":"who are you?","parameters":{"max_new_token":64}}' -H 'Content-Type: application/json' -H 'Authorization: Bearer <HF_TOKEN>'
@@ -45,6 +46,13 @@ class MessageStreamer:
45
  self.request_headers = {
46
  "Content-Type": "application/json",
47
  }
 
 
 
 
 
 
 
48
  # References:
49
  # huggingface_hub/inference/_client.py:
50
  # class InferenceClient > def text_generation()
 
36
  prompt: str = None,
37
  temperature: float = 0.01,
38
  max_new_tokens: int = 8192,
39
+ api_key: str = None,
40
  ):
41
  # https://huggingface.co/docs/api-inference/detailed_parameters?code=curl
42
  # curl --proxy http://<server>:<port> https://api-inference.huggingface.co/models/<org>/<model_name> -X POST -d '{"inputs":"who are you?","parameters":{"max_new_token":64}}' -H 'Content-Type: application/json' -H 'Authorization: Bearer <HF_TOKEN>'
 
46
  self.request_headers = {
47
  "Content-Type": "application/json",
48
  }
49
+
50
+ if api_key:
51
+ logger.note(
52
+ f"Using API Key: {api_key[:3]}{(len(api_key)-7)*'*'}{api_key[-4:]}"
53
+ )
54
+ self.request_headers["Authorization"] = f"Bearer {api_key}"
55
+
56
  # References:
57
  # huggingface_hub/inference/_client.py:
58
  # class InferenceClient > def text_generation()