Ashrafb commited on
Commit
8c3bfa5
1 Parent(s): c305f1e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +15 -28
main.py CHANGED
@@ -1,15 +1,7 @@
1
- from fastapi import FastAPI, File, UploadFile
2
  from fastapi import FastAPI, File, UploadFile, Form, Request
3
  from fastapi.responses import HTMLResponse, FileResponse
4
  from fastapi.staticfiles import StaticFiles
5
  from fastapi.templating import Jinja2Templates
6
- from fastapi import FastAPI, File, UploadFile, HTTPException
7
- from fastapi.responses import JSONResponse
8
- from fastapi.responses import StreamingResponse
9
- from fastapi import FastAPI, Request, Form
10
- from fastapi.responses import HTMLResponse
11
- from fastapi.staticfiles import StaticFiles
12
- from fastapi.templating import Jinja2Templates
13
  from huggingface_hub import InferenceClient
14
  import random
15
 
@@ -23,12 +15,13 @@ app = FastAPI()
23
 
24
 
25
  def format_prompt(message, history):
26
- prompt = "<s>"
27
- for user_prompt, bot_response in history:
28
- prompt += f"[INST] {user_prompt} [/INST]"
29
- prompt += f" {bot_response}</s> "
30
- prompt += f"[INST] {message} [/INST]"
31
- return prompt
 
32
 
33
  def generate(prompt, history, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
34
  temperature = float(temperature)
@@ -49,6 +42,7 @@ def generate(prompt, history, temperature=0.9, max_new_tokens=512, top_p=0.95, r
49
 
50
  stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
51
  output = ""
 
52
 
53
  for response in stream:
54
  token_text = response.token.text.strip()
@@ -56,21 +50,15 @@ def generate(prompt, history, temperature=0.9, max_new_tokens=512, top_p=0.95, r
56
  # Decode the token text to handle encoded characters
57
  decoded_text = token_text.encode("utf-8", "backslashreplace").decode("utf-8")
58
 
59
- # Add the decoded token text to the output
60
- output += decoded_text
61
-
62
- return output
63
-
64
-
65
-
66
-
67
-
68
-
69
-
70
-
71
-
72
 
 
 
 
 
73
 
 
74
 
75
 
76
  @app.post("/generate/")
@@ -88,4 +76,3 @@ app.mount("/", StaticFiles(directory="static", html=True), name="static")
88
  @app.get("/")
89
  def index() -> FileResponse:
90
  return FileResponse(path="/app/static/index.html", media_type="text/html")
91
-
 
 
1
  from fastapi import FastAPI, File, UploadFile, Form, Request
2
  from fastapi.responses import HTMLResponse, FileResponse
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.templating import Jinja2Templates
 
 
 
 
 
 
 
5
  from huggingface_hub import InferenceClient
6
  import random
7
 
 
15
 
16
 
17
  def format_prompt(message, history):
18
+ prompt = "<s>"
19
+ for user_prompt, bot_response in history:
20
+ prompt += f"[INST] {user_prompt} [/INST]"
21
+ prompt += f" {bot_response}</s> "
22
+ prompt += f"[INST] {message} [/INST]"
23
+ return prompt
24
+
25
 
26
  def generate(prompt, history, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
27
  temperature = float(temperature)
 
42
 
43
  stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
44
  output = ""
45
+ word = ""
46
 
47
  for response in stream:
48
  token_text = response.token.text.strip()
 
50
  # Decode the token text to handle encoded characters
51
  decoded_text = token_text.encode("utf-8", "backslashreplace").decode("utf-8")
52
 
53
+ # Add the decoded letter to the current word
54
+ word += decoded_text
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ # If the token is a space or the end of the stream, add the word to the output and reset the word
57
+ if token_text == " " or response.is_end_of_stream:
58
+ output += word + " "
59
+ word = ""
60
 
61
+ return output
62
 
63
 
64
  @app.post("/generate/")
 
76
  @app.get("/")
77
  def index() -> FileResponse:
78
  return FileResponse(path="/app/static/index.html", media_type="text/html")