OjciecTadeusz committed
Commit cc8c305 • 1 parent: 9b6975c
Update main.py
main.py CHANGED
@@ -1,4 +1,5 @@
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Depends
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from pydantic import BaseModel
 from huggingface_hub import InferenceClient
 import uvicorn
@@ -9,15 +10,16 @@ from dotenv import load_dotenv
 # Load environment variables
 load_dotenv()
 
-# Initialize FastAPI app
+# Initialize FastAPI app and security
 app = FastAPI()
+security = HTTPBearer()
 
-# Get
+# Get HuggingFace token from environment variable
 HF_TOKEN = os.getenv("HF_TOKEN")
 if not HF_TOKEN:
-    raise ValueError("HF_TOKEN environment variable not set")
+    raise ValueError("HF_TOKEN environment variable is not set")
 
-# Initialize
+# Initialize HuggingFace client with token
 client = InferenceClient(
     model="mistralai/Mixtral-8x7B-Instruct-v0.1",
     token=HF_TOKEN
@@ -36,10 +38,10 @@ class GenerationRequest(BaseModel):
     top_p: Optional[float] = 0.95
 
 def format_prompt(message: str, history: List[ChatMessage] = None, system_message: str = None) -> str:
-    prompt = "
+    prompt = ""
 
     if system_message:
-        prompt += f"[INST] {system_message} [/INST]</s>"
+        prompt += f"<s>[INST] {system_message} [/INST]</s>"
 
     if history:
         for msg in history:
@@ -51,37 +53,59 @@ def format_prompt(message: str, history: List[ChatMessage] = None, system_messag
     prompt += f"<s>[INST] {message} [/INST]"
     return prompt
 
+async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
+    if credentials.credentials != HF_TOKEN:
+        raise HTTPException(
+            status_code=401,
+            detail="Invalid authentication credentials"
+        )
+    return credentials.credentials
+
 @app.post("/generate/")
-async def generate_text(request: GenerationRequest):
+async def generate_text(
+    request: GenerationRequest,
+    token: str = Depends(verify_token)
+):
     try:
         message = request.prompt if request.prompt else request.message
         if not message:
-            raise HTTPException(
+            raise HTTPException(
+                status_code=400,
+                detail="Either 'prompt' or 'message' must be provided"
+            )
 
-        # Format the prompt
         formatted_prompt = format_prompt(
             message=message,
             history=request.history,
             system_message=request.system_message
         )
 
-
+        parameters = {
+            "temperature": max(request.temperature, 0.01),
+            "top_p": request.top_p,
+            "max_new_tokens": 1048,
+            "do_sample": True,
+            "return_full_text": False
+        }
+
         response = client.text_generation(
             formatted_prompt,
-
-            temperature=max(request.temperature, 0.01),
-            top_p=request.top_p,
-            do_sample=True,
-            seed=42
+            **parameters
         )
-
+
         if not response:
-            raise HTTPException(
+            raise HTTPException(
+                status_code=500,
+                detail="No response received from model"
+            )
 
         return {"response": response}
 
     except Exception as e:
-        raise HTTPException(
+        raise HTTPException(
+            status_code=500,
+            detail=f"Error generating response: {str(e)}"
+        )
 
 @app.get("/health")
 async def health_check():