OjciecTadeusz committed on
Commit
e5928ae
1 Parent(s): b58b5c3

Update main.py

Files changed (1)
  1. main.py +208 -152
main.py CHANGED
@@ -1,165 +1,221 @@
-
-
-from fastapi import FastAPI, HTTPException, Depends
-from fastapi.security.api_key import APIKeyHeader
+from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from huggingface_hub import InferenceClient, HfApi
-from typing import List, Optional
-import os
-from dotenv import load_dotenv
-
-# Load environment variables
-load_dotenv()
+from huggingface_hub import InferenceClient
+import uvicorn
 
-# Initialize FastAPI app
 app = FastAPI()
 
-# Get HuggingFace token from environment variable
-HF_TOKEN = os.getenv("HF_TOKEN")
-if not HF_TOKEN:
-    raise ValueError("HF_TOKEN environment variable is not set")
+client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
-# Setup API key authorization
-API_KEY_NAME = "Authorization"
-api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=True)
+class Item(BaseModel):
+    prompt: str
+    history: list
+    system_prompt: str
+    temperature: float = 0.01
+    top_p: float = 1.0
+    details: bool = True
+    return_full_text: bool = False
+    stream: bool = False
+
+def format_prompt(message, history):
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
 
-# Initialize HuggingFace client
-try:
-    client = InferenceClient(
-        "mistralai/Mixtral-8x7B-Instruct-v0.1",
-        token=HF_TOKEN
+def generate(item: Item):
+    temperature = float(item.temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(item.top_p)
+
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=1048,
+        top_p=top_p,
+        repetition_penalty=1.0,
+        do_sample=True,
+        seed=42,
     )
-    # Verify token is valid
-    hf_api = HfApi(token=HF_TOKEN)
-    hf_api.whoami()
-except Exception as e:
-    raise ValueError(f"Failed to initialize HuggingFace client: {str(e)}")
 
-class ChatMessage(BaseModel):
-    role: str
-    content: str
+    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=item.stream, details=item.details, return_full_text=item.return_full_text)
+    output = ""
 
-class GenerationRequest(BaseModel):
-    prompt: str
-    message: Optional[str] = None
-    system_message: Optional[str] = None
-    history: Optional[List[ChatMessage]] = None
-    temperature: Optional[float] = 0.7
-    top_p: Optional[float] = 0.95
-
-def format_prompt(message: str, history: List[ChatMessage] = None, system_message: str = None) -> str:
-    prompt = ""
-
-    if system_message:
-        prompt += f"<s>[INST] {system_message} [/INST]</s>"
-
-    if history:
-        for msg in history:
-            if msg.role == "user":
-                prompt += f"<s>[INST] {msg.content} [/INST]"
-            else:
-                prompt += f" {msg.content}</s>"
-
-    prompt += f"<s>[INST] {message} [/INST]"
-    return prompt
-
-async def verify_token(api_key_header: str = Depends(api_key_header)):
-    if not api_key_header.startswith("Bearer "):
-        raise HTTPException(
-            status_code=401,
-            detail="Bearer token missing"
-        )
-    token = api_key_header.replace("Bearer ", "")
-    if token != HF_TOKEN:
-        raise HTTPException(
-            status_code=401,
-            detail="Invalid authentication credentials"
-        )
-    return token
+    for response in stream:
+        output += response.token.text
+    return output
 
 @app.post("/generate/")
-async def generate_text(
-    request: GenerationRequest,
-    token: str = Depends(verify_token)
-):
+async def generate_text(item: Item):
     try:
-        message = request.prompt if request.prompt else request.message
-        if not message:
-            return [
-                {
-                    "msg": "MSG!"
-                }
-            ]
-
-        formatted_prompt = format_prompt(
-            message=message,
-            history=request.history,
-            system_message=request.system_message
-        )
-
-        response = client.text_generation(
-            formatted_prompt,
-            temperature=max(request.temperature, 0.01),
-            top_p=request.top_p,
-            max_new_tokens=1048,
-            do_sample=True,
-            return_full_text=False
-        )
-
-        if not response:
-            return [
-                {
-                    "detail": [
-                        {
-                            # "type": "server_error",
-                            "loc": ["server"],
-                            "msg": "No response received from model",
-                            "input": None
-                        }
-                    ]
-                }
-            ]
-
-        # Construct the custom JSON response
-        return [
-            {
-                "msg": response
-                # "msg": [
-                #     {
-                #         # "type": "success",
-                #         # "loc":[
-                #         #     "body",
-                #         #     "prompt"
-                #         # ],
-                #         # "loc": ["body"],
-                #         # "msg": [
-                #         #     response,
-                #         #     formatted_prompt
-                #         # ],
-
-                #     }
-                # ]
-            }
-        ]
-
+        response = generate(item)
+        return {"response": response}
     except Exception as e:
-        return [
-            {
-                "detail": [
-                    {
-                        "type": "server_error",
-                        "loc": ["server"],
-                        "msg": f"Error generating response: {str(e)}",
-                        "input": None
-                    }
-                ]
-            }
-        ]
-
-@app.get("/health")
-async def health_check():
-    return {
-        "status": "healthy",
-        "huggingface_client": "initialized",
-        "auth_required": True
-    }
+        raise HTTPException(status_code=500, detail=str(e))
+
+# from fastapi import FastAPI, HTTPException, Depends
+# from fastapi.security.api_key import APIKeyHeader
+# from pydantic import BaseModel
+# from huggingface_hub import InferenceClient, HfApi
+# from typing import List, Optional
+# import os
+# from dotenv import load_dotenv
+
+# # Load environment variables
+# load_dotenv()
+
+# # Initialize FastAPI app
+# app = FastAPI()
+
+# # Get HuggingFace token from environment variable
+# HF_TOKEN = os.getenv("HF_TOKEN")
+# if not HF_TOKEN:
+#     raise ValueError("HF_TOKEN environment variable is not set")
+
+# # Setup API key authorization
+# API_KEY_NAME = "Authorization"
+# api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=True)
+
+# # Initialize HuggingFace client
+# try:
+#     client = InferenceClient(
+#         "mistralai/Mixtral-8x7B-Instruct-v0.1",
+#         token=HF_TOKEN
+#     )
+#     # Verify token is valid
+#     hf_api = HfApi(token=HF_TOKEN)
+#     hf_api.whoami()
+# except Exception as e:
+#     raise ValueError(f"Failed to initialize HuggingFace client: {str(e)}")
+
+# class ChatMessage(BaseModel):
+#     role: str
+#     content: str
+
+# class GenerationRequest(BaseModel):
+#     prompt: str
+#     message: Optional[str] = None
+#     system_message: Optional[str] = None
+#     history: Optional[List[ChatMessage]] = None
+#     temperature: Optional[float] = 0.7
+#     top_p: Optional[float] = 0.95
+
+# def format_prompt(message: str, history: List[ChatMessage] = None, system_message: str = None) -> str:
+#     prompt = ""
+
+#     if system_message:
+#         prompt += f"<s>[INST] {system_message} [/INST]</s>"
+
+#     if history:
+#         for msg in history:
+#             if msg.role == "user":
+#                 prompt += f"<s>[INST] {msg.content} [/INST]"
+#             else:
+#                 prompt += f" {msg.content}</s>"
+
+#     prompt += f"<s>[INST] {message} [/INST]"
+#     return prompt
+
+# async def verify_token(api_key_header: str = Depends(api_key_header)):
+#     if not api_key_header.startswith("Bearer "):
+#         raise HTTPException(
+#             status_code=401,
+#             detail="Bearer token missing"
+#         )
+#     token = api_key_header.replace("Bearer ", "")
+#     if token != HF_TOKEN:
+#         raise HTTPException(
+#             status_code=401,
+#             detail="Invalid authentication credentials"
+#         )
+#     return token
+
+# @app.post("/generate/")
+# async def generate_text(
+#     request: GenerationRequest,
+#     token: str = Depends(verify_token)
+# ):
+#     try:
+#         message = request.prompt if request.prompt else request.message
+#         if not message:
+#             return [
+#                 {
+#                     "msg": "MSG!"
+#                 }
+#             ]
+
+#         formatted_prompt = format_prompt(
+#             message=message,
+#             history=request.history,
+#             system_message=request.system_message
+#         )
+
+#         response = client.text_generation(
+#             formatted_prompt,
+#             temperature=max(request.temperature, 0.01),
+#             top_p=request.top_p,
+#             max_new_tokens=1048,
+#             do_sample=True,
+#             return_full_text=False
+#         )
+
+#         if not response:
+#             return [
+#                 {
+#                     "detail": [
+#                         {
+#                             # "type": "server_error",
+#                             "loc": ["server"],
+#                             "msg": "No response received from model",
+#                             "input": None
+#                         }
+#                     ]
+#                 }
+#             ]
+
+#         # Construct the custom JSON response
+#         return [
+#             {
+#                 "msg": response
+#                 # "msg": [
+#                 #     {
+#                 #         # "type": "success",
+#                 #         # "loc":[
+#                 #         #     "body",
+#                 #         #     "prompt"
+#                 #         # ],
+#                 #         # "loc": ["body"],
+#                 #         # "msg": [
+#                 #         #     response,
+#                 #         #     formatted_prompt
+#                 #         # ],
+
+#                 #     }
+#                 # ]
+#             }
+#         ]
+
+#     except Exception as e:
+#         return [
+#             {
+#                 "detail": [
+#                     {
+#                         "type": "server_error",
+#                         "loc": ["server"],
+#                         "msg": f"Error generating response: {str(e)}",
+#                         "input": None
+#                     }
+#                 ]
+#             }
+#         ]
+
+# @app.get("/health")
+# async def health_check():
+#     return {
+#         "status": "healthy",
+#         "huggingface_client": "initialized",
+#         "auth_required": True
+#     }
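
Below is a minimal client-side sketch for exercising the new /generate/ endpoint. It is illustrative only: the host, port, and payload values are assumptions (e.g. the app served locally with uvicorn main:app), not part of the commit. Two details of the new code shape the payload: format_prompt unpacks each history entry as a (user_prompt, bot_response) pair, so history must be a list of two-element pairs; and the token loop in generate (response.token.text) only receives token objects when client.text_generation runs with stream=True and details=True, since with stream=False it returns a plain string rather than an iterable of tokens, so those flags are set explicitly here rather than left at the Item defaults.

# client_example.py -- hypothetical usage sketch, not part of the commit
import requests

payload = {
    "prompt": "Summarize FastAPI in one sentence.",
    "system_prompt": "You are a concise assistant.",
    # Each entry must be a [user_message, assistant_reply] pair to match
    # the two-value unpacking in format_prompt().
    "history": [["Hi", "Hello! How can I help?"]],
    "temperature": 0.7,
    "top_p": 0.95,
    # generate() accumulates response.token.text, which text_generation()
    # only yields when streaming with details is enabled.
    "stream": True,
    "details": True,
}

# Assumed local server; a deployed Space would use its own URL.
resp = requests.post("http://127.0.0.1:8000/generate/", json=payload)
resp.raise_for_status()
print(resp.json()["response"])

With this payload, format_prompt produces a Mixtral-instruct-style string along the lines of "<s>[INST] Hi [/INST] Hello! How can I help?</s> [INST] You are a concise assistant., Summarize FastAPI in one sentence. [/INST]", since generate joins the system prompt to the user prompt with a comma before formatting.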