Ashrafb committed
Commit 94ca598
1 parent: 9c5e564

Update main.py

Files changed (1)
  1. main.py +22 -42
main.py CHANGED
@@ -1,46 +1,37 @@
- from fastapi import FastAPI, File, UploadFile
- from fastapi import FastAPI, File, UploadFile, Form, Request
- from fastapi.responses import HTMLResponse, FileResponse
- from fastapi.staticfiles import StaticFiles
- from fastapi.templating import Jinja2Templates
- from fastapi import FastAPI, File, UploadFile, HTTPException
- from fastapi.responses import JSONResponse
- from fastapi.responses import StreamingResponse
  from fastapi import FastAPI, Request, Form
- from fastapi.responses import HTMLResponse
+ from fastapi.responses import HTMLResponse, FileResponse
  from fastapi.staticfiles import StaticFiles
- from fastapi.templating import Jinja2Templates
  from huggingface_hub import InferenceClient
- import random
-
- API_URL = "https://api-inference.huggingface.co/models/"
-
- client = InferenceClient(
-     "mistralai/Mistral-7B-Instruct-v0.1"
- )
-
- app = FastAPI()
-
-
- def format_prompt(message, history):
-     prompt = "<s>"
-     for user_prompt, bot_response in history:
-         prompt += f"[INST] {user_prompt} [/INST]"
-         prompt += f" {bot_response}</s> "
-     prompt += f"[INST] {message} [/INST]"
-     return prompt
-
  import logging

  # Initialize the logger
  logging.basicConfig(level=logging.INFO) # Adjust the logging level as needed
  logger = logging.getLogger(__name__)

+ # Hugging Face Inference Client
+ client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
+
+ app = FastAPI()
+
+ # Format the prompt for the model
+ def format_prompt(message, history):
+     prompt = "<s>"
+     for user_prompt, bot_response in history:
+         prompt += f"[INST] {user_prompt} [/INST]"
+         prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
+
+ # Generate response from the model
  def generate(prompt: str, history: list, temperature: float = 0.9, max_new_tokens: int = 512, top_p: float = 0.95, repetition_penalty: float = 1.0) -> str:
      try:
          formatted_prompt = format_prompt(prompt, history)
          logger.info(f"Formatted prompt: {formatted_prompt}")
-         bot_response = client.text_generation(formatted_prompt, temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, repetition_penalty=repetition_penalty, stream=True, details=True, return_full_text=False)
+         bot_response = client.text_generation(
+             formatted_prompt, temperature=temperature, max_new_tokens=max_new_tokens,
+             top_p=top_p, repetition_penalty=repetition_penalty, stream=True,
+             details=True, return_full_text=False
+         )
          output = [response.token.text.strip() for response in bot_response if response.token.text.strip()]
          logger.info(f"Bot response tokens: {output}")
          return " ".join(output)
@@ -48,16 +39,6 @@ def generate(prompt: str, history: list, temperature: float = 0.9, max_new_token
          logger.error(f"Error generating text: {e}")
          return ""

-
-
-
-
-
-
-
-
-
-
  @app.post("/generate/")
  async def generate_chat(request: Request, prompt: str = Form(...), history: str = Form(...), temperature: float = Form(0.9), max_new_tokens: int = Form(512), top_p: float = Form(0.95), repetition_penalty: float = Form(1.0)):
      history = eval(history) # Convert history string back to list
@@ -73,5 +54,4 @@ app.mount("/", StaticFiles(directory="static", html=True), name="static")

  @app.get("/")
  def index() -> FileResponse:
-     return FileResponse(path="/app/static/index.html", media_type="text/html")
-
+     return FileResponse(path="static/index.html", media_type="text/html")
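For reference, a minimal sketch of calling the /generate/ endpoint added in this file. The host and port are assumptions (7860 is a common default for Hugging Face Spaces), as are the example field values; the history field is sent as a Python-literal string because the handler converts it back with eval().

# Hypothetical client call against a locally running instance of this app.
import requests

response = requests.post(
    "http://localhost:7860/generate/",  # assumed base URL, not part of this commit
    data={
        "prompt": "What is FastAPI?",
        "history": "[]",  # handler does eval(history), so send a Python-literal list
        "temperature": 0.9,
        "max_new_tokens": 512,
        "top_p": 0.95,
        "repetition_penalty": 1.0,
    },
)
print(response.status_code, response.text)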