OjciecTadeusz committed on
Commit
cc8c305
1 Parent(s): 9b6975c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +43 -19
main.py CHANGED
@@ -1,4 +1,5 @@
1
- from fastapi import FastAPI, HTTPException
 
2
  from pydantic import BaseModel
3
  from huggingface_hub import InferenceClient
4
  import uvicorn
@@ -9,15 +10,16 @@ from dotenv import load_dotenv
9
  # Load environment variables
10
  load_dotenv()
11
 
12
- # Initialize FastAPI app
13
  app = FastAPI()
 
14
 
15
- # Get Hugging Face token from environment variable
16
  HF_TOKEN = os.getenv("HF_TOKEN")
17
  if not HF_TOKEN:
18
- raise ValueError("HF_TOKEN environment variable not set")
19
 
20
- # Initialize Hugging Face client with token
21
  client = InferenceClient(
22
  model="mistralai/Mixtral-8x7B-Instruct-v0.1",
23
  token=HF_TOKEN
@@ -36,10 +38,10 @@ class GenerationRequest(BaseModel):
36
  top_p: Optional[float] = 0.95
37
 
38
  def format_prompt(message: str, history: List[ChatMessage] = None, system_message: str = None) -> str:
39
- prompt = "<s>"
40
 
41
  if system_message:
42
- prompt += f"[INST] {system_message} [/INST]</s>"
43
 
44
  if history:
45
  for msg in history:
@@ -51,37 +53,59 @@ def format_prompt(message: str, history: List[ChatMessage] = None, system_messag
51
  prompt += f"<s>[INST] {message} [/INST]"
52
  return prompt
53
 
 
 
 
 
 
 
 
 
54
  @app.post("/generate/")
55
- async def generate_text(request: GenerationRequest):
 
 
 
56
  try:
57
  message = request.prompt if request.prompt else request.message
58
  if not message:
59
- raise HTTPException(status_code=400, detail="Either 'prompt' or 'message' must be provided")
 
 
 
60
 
61
- # Format the prompt
62
  formatted_prompt = format_prompt(
63
  message=message,
64
  history=request.history,
65
  system_message=request.system_message
66
  )
67
 
68
- # Make the request to Hugging Face
 
 
 
 
 
 
 
69
  response = client.text_generation(
70
  formatted_prompt,
71
- max_new_tokens=1024,
72
- temperature=max(request.temperature, 0.01),
73
- top_p=request.top_p,
74
- do_sample=True,
75
- seed=42
76
  )
77
-
78
  if not response:
79
- raise HTTPException(status_code=500, detail="Empty response from model")
 
 
 
80
 
81
  return {"response": response}
82
 
83
  except Exception as e:
84
- raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
 
 
 
85
 
86
  @app.get("/health")
87
  async def health_check():
 
1
+ from fastapi import FastAPI, HTTPException, Depends
2
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
3
  from pydantic import BaseModel
4
  from huggingface_hub import InferenceClient
5
  import uvicorn
 
10
  # Load environment variables
11
  load_dotenv()
12
 
13
+ # Initialize FastAPI app and security
14
  app = FastAPI()
15
+ security = HTTPBearer()
16
 
17
+ # Get HuggingFace token from environment variable
18
  HF_TOKEN = os.getenv("HF_TOKEN")
19
  if not HF_TOKEN:
20
+ raise ValueError("HF_TOKEN environment variable is not set")
21
 
22
+ # Initialize HuggingFace client with token
23
  client = InferenceClient(
24
  model="mistralai/Mixtral-8x7B-Instruct-v0.1",
25
  token=HF_TOKEN
 
38
  top_p: Optional[float] = 0.95
39
 
40
  def format_prompt(message: str, history: List[ChatMessage] = None, system_message: str = None) -> str:
41
+ prompt = ""
42
 
43
  if system_message:
44
+ prompt += f"<s>[INST] {system_message} [/INST]</s>"
45
 
46
  if history:
47
  for msg in history:
 
53
  prompt += f"<s>[INST] {message} [/INST]"
54
  return prompt
55
 
56
+ async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
57
+ if credentials.credentials != HF_TOKEN:
58
+ raise HTTPException(
59
+ status_code=401,
60
+ detail="Invalid authentication credentials"
61
+ )
62
+ return credentials.credentials
63
+
64
  @app.post("/generate/")
65
+ async def generate_text(
66
+ request: GenerationRequest,
67
+ token: str = Depends(verify_token)
68
+ ):
69
  try:
70
  message = request.prompt if request.prompt else request.message
71
  if not message:
72
+ raise HTTPException(
73
+ status_code=400,
74
+ detail="Either 'prompt' or 'message' must be provided"
75
+ )
76
 
 
77
  formatted_prompt = format_prompt(
78
  message=message,
79
  history=request.history,
80
  system_message=request.system_message
81
  )
82
 
83
+ parameters = {
84
+ "temperature": max(request.temperature, 0.01),
85
+ "top_p": request.top_p,
86
+ "max_new_tokens": 1048,
87
+ "do_sample": True,
88
+ "return_full_text": False
89
+ }
90
+
91
  response = client.text_generation(
92
  formatted_prompt,
93
+ **parameters
 
 
 
 
94
  )
95
+
96
  if not response:
97
+ raise HTTPException(
98
+ status_code=500,
99
+ detail="No response received from model"
100
+ )
101
 
102
  return {"response": response}
103
 
104
  except Exception as e:
105
+ raise HTTPException(
106
+ status_code=500,
107
+ detail=f"Error generating response: {str(e)}"
108
+ )
109
 
110
  @app.get("/health")
111
  async def health_check():