Ashrafb committed on
Commit 352772b
1 Parent(s): 143343e

Update main.py

Files changed (1)
  1. main.py +42 -50
main.py CHANGED
@@ -1,57 +1,49 @@
-from fastapi import FastAPI, Request, Form
-from fastapi.responses import HTMLResponse, FileResponse
-from fastapi.staticfiles import StaticFiles
+from fastapi import FastAPI, WebSocket
+from fastapi.responses import HTMLResponse
+from pydantic import BaseModel
 from huggingface_hub import InferenceClient
-import logging
-
-# Initialize the logger
-logging.basicConfig(level=logging.INFO)  # Adjust the logging level as needed
-logger = logging.getLogger(__name__)
-
-# Hugging Face Inference Client
-client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
+import json
 
 app = FastAPI()
 
-# Format the prompt for the model
-def format_prompt(message, history):
-    prompt = "<s>"
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
-
-# Generate response from the model
-def generate(prompt: str, history: list, temperature: float = 0.9, max_new_tokens: int = 512, top_p: float = 0.95, top_k: int = 50, repetition_penalty: float = 1.0) -> str:
-    try:
-        formatted_prompt = format_prompt(prompt, history)
-        logger.info(f"Formatted prompt: {formatted_prompt}")
-        bot_response = client.text_generation(
-            formatted_prompt, temperature=temperature, max_new_tokens=max_new_tokens,
-            top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, stream=True,
-            details=True, return_full_text=False
-        )
-        output = [response.token.text.strip() for response in bot_response if response.token.text.strip()]
-        logger.info(f"Bot response tokens: {output}")
-        return " ".join(output)
-    except Exception as e:
-        logger.error(f"Error generating text: {e}")
-        return ""
-
-@app.post("/generate/")
-async def generate_chat(request: Request, prompt: str = Form(...), history: str = Form(...), temperature: float = Form(0.9), max_new_tokens: int = Form(512), top_p: float = Form(0.95), top_k: int = Form(50), repetition_penalty: float = Form(1.0)):
-    history = eval(history)  # Convert history string back to list
-    response = generate(prompt, history, temperature, max_new_tokens, top_p, top_k, repetition_penalty)
-
-    # Remove any HTML tags from the response
-    import re
-    response = re.sub('<[^<]+?>', '', response)
-
-    return {"response": response}
-
-app.mount("/", StaticFiles(directory="static", html=True), name="static")
+client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
+
+class MessageRequest(BaseModel):
+    message: str
+    history: list[tuple[str, str]]
+    system_message: str
+    max_tokens: int
+    temperature: float
+    top_p: float
 
 @app.get("/")
-def index() -> FileResponse:
-    return FileResponse(path="static/index.html", media_type="text/html")
+async def get():
+    with open("index.html", "r") as file:
+        return HTMLResponse(content=file.read(), media_type="text/html")
+
+@app.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket):
+    await websocket.accept()
+    while True:
+        data = await websocket.receive_text()
+        request = MessageRequest(**json.loads(data))
+        messages = [{"role": "system", "content": request.system_message}]
+        for val in request.history:
+            if val[0]:
+                messages.append({"role": "user", "content": val[0]})
+            if val[1]:
+                messages.append({"role": "assistant", "content": val[1]})
+        messages.append({"role": "user", "content": request.message})
+
+        response = ""
+        for message in client.chat_completion(
+            messages,
+            max_tokens=request.max_tokens,
+            stream=True,
+            temperature=request.temperature,
+            top_p=request.top_p,
+        ):
+            token = message.choices[0].delta.content
+            response += token
+            await websocket.send_text(json.dumps({"token": token}))
+
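
After this change, the app serves index.html at "/" and streams tokens over the "/ws" WebSocket: each incoming frame is a JSON-encoded MessageRequest, and each generated token is pushed back as a {"token": ...} frame. Below is a minimal client sketch, not part of this commit; the port (7860), the `websockets` package, and the read timeout are assumptions.

# Hypothetical client for the /ws endpoint added in this commit -- a sketch, not repository code.
# Assumes the app is running locally on port 7860 and `pip install websockets`.
import asyncio
import json

import websockets


async def main():
    payload = {
        "message": "Hello!",
        "history": [],  # prior (user, assistant) turns, matching MessageRequest.history
        "system_message": "You are a helpful assistant.",
        "max_tokens": 256,
        "temperature": 0.7,
        "top_p": 0.95,
    }
    async with websockets.connect("ws://localhost:7860/ws") as ws:
        # One request in, a stream of {"token": ...} frames out.
        await ws.send(json.dumps(payload))
        try:
            # The endpoint sends no explicit end-of-stream marker, so stop after a quiet period.
            while True:
                frame = json.loads(await asyncio.wait_for(ws.recv(), timeout=10))
                print(frame["token"], end="", flush=True)
        except asyncio.TimeoutError:
            pass


asyncio.run(main())

Each history entry is sent as a two-element JSON array; pydantic coerces it into the tuple[str, str] declared on MessageRequest.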