omkar56 committed on
Commit
a438652
1 Parent(s): bf69a80

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +32 -31
main.py CHANGED
@@ -1,37 +1,45 @@
1
- from fastapi import FastAPI, Request, Body
 
 
2
  from huggingface_hub import InferenceClient
3
  import random
4
 
5
  API_URL = "https://api-inference.huggingface.co/models/"
 
6
 
7
- client = InferenceClient(
8
- "mistralai/Mistral-7B-Instruct-v0.1"
9
- )
10
-
11
  app = FastAPI()
12
 
13
def format_prompt(message, history):
    """Build a Mistral-instruct prompt string from prior chat turns plus the new message.

    Each (user, bot) pair in *history* becomes "[INST] user [/INST] bot</s> ",
    prefixed once with "<s>" and followed by the current message wrapped in [INST] tags.
    """
    pieces = ["<s>"]
    for user_turn, bot_turn in history:
        pieces.append(f"[INST] {user_turn} [/INST] {bot_turn}</s> ")
    pieces.append(f"[INST] {message} [/INST]")
    return "".join(pieces)
20
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  @app.post("/api/v1/generate_text")
23
- def generate_text(request: Request, prompt: str = Body()):
24
- history = [] # You might need to handle this based on your actual usage
25
- print(f"request + {request}")
26
- temperature = request.headers.get("temperature", 0.5)
27
- # print(f"temperature + {temperature}")
28
- top_p = request.headers.get("top_p", 0.95)
29
- # print(f"top_p + {top_p}")
30
- repetition_penalty = request.headers.get("repetition_penalty", 1.0)
31
- # print(f"repetition_penalty + {repetition_penalty}")
 
32
 
 
33
  formatted_prompt = format_prompt(prompt, history)
34
- print(f"formatted_prompt + {formatted_prompt}")
35
  stream = client.text_generation(
36
  formatted_prompt,
37
  temperature=temperature,
@@ -41,15 +49,8 @@ def generate_text(request: Request, prompt: str = Body()):
41
  do_sample=True,
42
  seed=random.randint(0, 10**7),
43
  stream=False,
44
- details=False,
45
  return_full_text=False
46
  )
47
- # output = ""
48
-
49
- # for response in stream:
50
- # output += response.token.text
51
- # yield output
52
-
53
- # return output[len(output) - 1]
54
 
55
- return stream
 
1
import os
import random
from typing import Optional

from fastapi import FastAPI, Request, Body, HTTPException, Depends
from fastapi.security import APIKeyHeader
from huggingface_hub import InferenceClient

API_URL = "https://api-inference.huggingface.co/models/"
# SECURITY: avoid committing secrets to source control. Read the key from the
# environment; fall back to the original hard-coded value so existing callers
# keep working until the env var is provisioned.
API_KEY = os.environ.get("API_KEY", "abcd12345")

# Client pinned to the Mistral-7B instruct model hosted on the HF inference API.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")

app = FastAPI()

# API key is expected in the "api_key" request header; auto_error=False so the
# dependency below can return a custom 401 instead of FastAPI's default 403.
security = APIKeyHeader(name="api_key", auto_error=False)
 
 
 
 
 
 
14
 
15
def get_api_key(api_key: Optional[str] = Depends(security)):
    """FastAPI dependency validating the `api_key` request header.

    Raises HTTPException 401 when the header is missing or does not match
    API_KEY; otherwise returns the supplied key.
    """
    import secrets  # local import: constant-time comparison helper

    # compare_digest prevents timing attacks that a plain `!=` would allow.
    if api_key is None or not secrets.compare_digest(api_key, API_KEY):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return api_key
19
+
20
def format_prompt(message, history):
    """Serialize chat *history* and the current *message* into Mistral [INST] format.

    History is an iterable of (user, bot) pairs; the result starts with "<s>"
    and ends with the new message wrapped in [INST] ... [/INST].
    """
    conversation = "".join(
        f"[INST] {user_msg} [/INST] {bot_msg}</s> "
        for user_msg, bot_msg in history
    )
    return f"<s>{conversation}[INST] {message} [/INST]"
27
 
28
  @app.post("/api/v1/generate_text")
29
+ def generate_text(
30
+ request: Request,
31
+ body: dict = Body(...),
32
+ api_key: str = Depends(get_api_key)
33
+ ):
34
+ prompt = body.get("prompt", "")
35
+ sys_prompt = body.get("sysPrompt", "")
36
+ temperature = body.get("temperature", 0.5)
37
+ top_p = body.get("top_p", 0.95)
38
+ repetition_penalty = body.get("repetition_penalty", 1.0)
39
 
40
+ history = [] # You might need to handle this based on your actual usage
41
  formatted_prompt = format_prompt(prompt, history)
42
+
43
  stream = client.text_generation(
44
  formatted_prompt,
45
  temperature=temperature,
 
49
  do_sample=True,
50
  seed=random.randint(0, 10**7),
51
  stream=False,
52
+ details=False,
53
  return_full_text=False
54
  )
 
 
 
 
 
 
 
55
 
56
+ return stream