Valtry commited on
Commit
16c8676
Β·
verified Β·
1 Parent(s): 03b0e50

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -67
app.py CHANGED
@@ -1,9 +1,7 @@
1
- from fastapi import FastAPI, Request
2
- from fastapi.responses import StreamingResponse
3
  from pydantic import BaseModel
4
  import torch
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
- from threading import Thread
7
  import uvicorn
8
 
9
  # -----------------------
@@ -12,6 +10,7 @@ import uvicorn
12
  MODEL_ID = "microsoft/phi-2"
13
 
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
15
  model = AutoModelForCausalLM.from_pretrained(
16
  MODEL_ID,
17
  device_map="cpu",
@@ -26,36 +25,20 @@ torch.set_num_threads(2)
26
  # -----------------------
27
  app = FastAPI()
28
 
29
- # stop flag (global)
30
- stop_generation = False
31
-
32
  class ChatRequest(BaseModel):
33
  message: str
34
 
35
 
36
  @app.get("/")
37
  def home():
38
- return {"status": "Streaming API running πŸš€"}
39
-
40
-
41
- # -----------------------
42
- # STOP ENDPOINT
43
- # -----------------------
44
- @app.post("/stop")
45
- def stop():
46
- global stop_generation
47
- stop_generation = True
48
- return {"status": "stopping"}
49
 
50
 
51
  # -----------------------
52
- # STREAMING CHAT
53
  # -----------------------
54
  @app.post("/chat")
55
- async def chat(req: ChatRequest):
56
-
57
- global stop_generation
58
- stop_generation = False
59
 
60
  prompt = f"""You are a concise assistant.
61
  Return plain text only.
@@ -70,58 +53,34 @@ Assistant:"""
70
 
71
  inputs = tokenizer(prompt, return_tensors="pt")
72
 
73
- streamer = TextIteratorStreamer(
74
- tokenizer,
75
- skip_prompt=True,
76
- skip_special_tokens=True
 
 
 
77
  )
78
 
79
- # βœ… Define stop tokens
80
- stop_tokens = ["User:", "\n\n"]
81
-
82
- stop_token_ids = [
83
- tokenizer.encode(token, add_special_tokens=False)
84
- for token in stop_tokens
85
- ]
86
-
87
- def generate():
88
- model.generate(
89
- **inputs,
90
- streamer=streamer,
91
- max_new_tokens=100,
92
- temperature=0.5,
93
- do_sample=True,
94
- eos_token_id=tokenizer.eos_token_id,
95
- pad_token_id=tokenizer.eos_token_id
96
- )
97
 
98
- Thread(target=generate).start()
 
 
99
 
100
- async def stream():
101
- global stop_generation
102
 
103
- buffer_ids = []
104
 
105
- for token in streamer:
106
- if stop_generation:
107
- break
108
 
109
- # convert token β†’ ids
110
- token_ids = tokenizer.encode(token, add_special_tokens=False)
111
- buffer_ids.extend(token_ids)
112
 
113
- # πŸ”₯ STOP TOKEN CHECK (clean, not hacky)
114
- for stop_seq in stop_token_ids:
115
- if buffer_ids[-len(stop_seq):] == stop_seq:
116
- return
117
-
118
- yield token
119
-
120
- return StreamingResponse(
121
- stream(),
122
- media_type="text/plain",
123
- headers={"Transfer-Encoding": "identity"}
124
- )
125
 
126
  # -----------------------
127
  # START SERVER
 
1
+ from fastapi import FastAPI
 
2
  from pydantic import BaseModel
3
  import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
5
  import uvicorn
6
 
7
  # -----------------------
 
10
  MODEL_ID = "microsoft/phi-2"
11
 
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
13
+
14
  model = AutoModelForCausalLM.from_pretrained(
15
  MODEL_ID,
16
  device_map="cpu",
 
25
  # -----------------------
26
  app = FastAPI()
27
 
 
 
 
28
  class ChatRequest(BaseModel):
29
  message: str
30
 
31
 
32
  @app.get("/")
33
  def home():
34
+ return {"status": "API running πŸš€"}
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  # -----------------------
38
+ # CHAT (NO STREAMING)
39
  # -----------------------
40
  @app.post("/chat")
41
+ def chat(req: ChatRequest):
 
 
 
42
 
43
  prompt = f"""You are a concise assistant.
44
  Return plain text only.
 
53
 
54
  inputs = tokenizer(prompt, return_tensors="pt")
55
 
56
+ outputs = model.generate(
57
+ **inputs,
58
+ max_new_tokens=80,
59
+ temperature=0.5,
60
+ do_sample=True,
61
+ eos_token_id=tokenizer.eos_token_id,
62
+ pad_token_id=tokenizer.eos_token_id
63
  )
64
 
65
+ text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
+ # πŸ”₯ CLEAN OUTPUT
68
+ if "Assistant:" in text:
69
+ text = text.split("Assistant:")[-1]
70
 
71
+ if "User:" in text:
72
+ text = text.split("User:")[0]
73
 
74
+ text = text.strip()
75
 
76
+ # remove unwanted formatting
77
+ text = text.replace("\n", " ")
78
+ text = text.replace(" ", " ")
79
 
80
+ return {
81
+ "response": text
82
+ }
83
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  # -----------------------
86
  # START SERVER