graychensz committed on
Commit f13e39b · verified · 1 Parent(s): 0fd53b3

Create app.py

Files changed (1)
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
"""
A model worker executes the model.
"""
import argparse
import json
import uuid

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModel, AutoTokenizer
import torch
import uvicorn

from transformers.generation.streamers import BaseStreamer
from threading import Thread
from queue import Queue


class TokenStreamer(BaseStreamer):
    """Forwards generated token ids through a queue so they can be consumed
    one by one from another thread."""

    def __init__(self, skip_prompt: bool = False, timeout=None):
        self.skip_prompt = skip_prompt

        # variables used in the streaming process
        self.token_queue = Queue()
        self.stop_signal = None
        self.next_tokens_are_prompt = True
        self.timeout = timeout

    def put(self, value):
        # Called by model.generate() with each batch of newly generated token ids.
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TokenStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # The first call delivers the prompt tokens; skip them if requested.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        for token in value.tolist():
            self.token_queue.put(token)

    def end(self):
        # Signal the consumer that generation has finished.
        self.token_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.token_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value


class ModelWorker:
    def __init__(self, model_path, device='cuda'):
        self.device = device
        self.glm_model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
                                                   device=device).to(device).eval()
        self.glm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    @torch.inference_mode()
    def generate_stream(self, params):
        tokenizer, model = self.glm_tokenizer, self.glm_model

        prompt = params["prompt"]

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_new_tokens = int(params.get("max_new_tokens", 256))

        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        streamer = TokenStreamer(skip_prompt=True)
        # Run generation in a background thread so token ids can be yielded
        # as soon as the streamer receives them.
        thread = Thread(target=model.generate,
                        kwargs=dict(**inputs, max_new_tokens=int(max_new_tokens),
                                    temperature=float(temperature), top_p=float(top_p),
                                    streamer=streamer))
        thread.start()
        for token_id in streamer:
            yield (json.dumps({"token_id": token_id, "error_code": 0}) + "\n").encode()

    def generate_stream_gate(self, params):
        try:
            for x in self.generate_stream(params):
                yield x
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": "Server Error",
                "error_code": 1,
            }
            yield (json.dumps(ret) + "\n").encode()


app = FastAPI()


@app.post("/generate_stream")
async def generate_stream(request: Request):
    params = await request.json()

    generator = worker.generate_stream_gate(params)
    return StreamingResponse(generator)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10000)
    parser.add_argument("--model-path", type=str, default="THUDM/glm-4-voice-9b")
    args = parser.parse_args()

    worker = ModelWorker(args.model_path)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
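
For reference, a minimal client sketch for consuming the /generate_stream endpoint added above. It assumes the worker is running with the default host/port from the argument parser (localhost:10000), uses the third-party requests library, and the prompt and sampling values are placeholders, not values from this commit. Each line of the streamed response is a JSON object carrying either a "token_id" or an error payload.

# client_sketch.py -- illustrative only; host/port, prompt, and sampling
# parameters are assumptions, not part of the committed app.py.
import json
import requests

resp = requests.post(
    "http://localhost:10000/generate_stream",
    json={
        "prompt": "<placeholder prompt>",
        "temperature": 0.8,
        "top_p": 0.9,
        "max_new_tokens": 64,
    },
    stream=True,  # read the newline-delimited JSON chunks as they arrive
)

token_ids = []
for line in resp.iter_lines():
    if not line:
        continue
    payload = json.loads(line)
    if payload.get("error_code", 0) != 0:
        raise RuntimeError(f"server error: {payload}")
    token_ids.append(payload["token_id"])

print(token_ids)  # decode with the matching tokenizer if text output is needed

The server streams raw token ids rather than decoded text, so a real client would load the same tokenizer (e.g. via AutoTokenizer.from_pretrained) to turn the collected ids back into text or audio tokens.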