Commit 40ff596 by ratyim
Parent(s): e6137da

Create app.py

Files changed (1):
  1. app.py  +147 -0

app.py ADDED
import spaces

import os
import gradio as gr
import torch
from threading import Thread

from typing import Union
from pathlib import Path
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer
)

ModelType = Union[PreTrainedModel]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)


def _resolve_path(path: Union[str, Path]) -> Path:
    return Path(path).expanduser().resolve()


def load_model_and_tokenizer(
        model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
    model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=trust_remote_code, device_map='auto')
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=trust_remote_code, use_fast=False)
    return model, tokenizer


model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)


class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # GLM-4 defines several EOS token ids; stop as soon as the last generated token is one of them.
        stop_ids = model.config.eos_token_id
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


def parse_text(text):
    """Escape user/model text for the Chatbot component, turning ``` fences into <pre><code> blocks."""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = '<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text


@spaces.GPU
def predict(history, max_length, top_p, temperature):
    stop = StopOnTokens()
    messages = []
    # Convert the Gradio (user, assistant) history pairs into the chat-template message format.
    for idx, (user_msg, model_msg) in enumerate(history):
        if idx == len(history) - 1 and not model_msg:
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    model_inputs = tokenizer.apply_chat_template(messages,
                                                 add_generation_prompt=True,
                                                 tokenize=True,
                                                 return_tensors="pt").to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1.2,
        "eos_token_id": model.config.eos_token_id,
    }
    # Run generation in a background thread and stream tokens back to the UI as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    for new_token in streamer:
        if new_token:
            history[-1][1] += new_token
            yield history


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">GLM-4-9B Gradio Simple Chat Demo</h1>""")
    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)


    def user(query, history):
        return "", history + [[parse_text(query), ""]]


    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, max_length, top_p, temperature], chatbot
    )
    emptyBtn.click(lambda: None, None, chatbot, queue=False)

demo.queue().launch()
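
For reference, a minimal standalone sketch of how the tuple-based Gradio history is turned into the message list that `tokenizer.apply_chat_template` consumes inside `predict` above; the conversation strings are made-up placeholders, not part of the commit.

# Standalone illustration of the history -> messages conversion used in predict().
# The sample turns below are placeholders.
history = [
    ["What is GLM-4?", "GLM-4 is an open chat model series from THUDM."],
    ["Summarize that in five words.", ""],  # last turn: user message with no reply yet
]

messages = []
for idx, (user_msg, model_msg) in enumerate(history):
    if idx == len(history) - 1 and not model_msg:
        messages.append({"role": "user", "content": user_msg})
        break
    if user_msg:
        messages.append({"role": "user", "content": user_msg})
    if model_msg:
        messages.append({"role": "assistant", "content": model_msg})

print(messages)
# [{'role': 'user', 'content': 'What is GLM-4?'},
#  {'role': 'assistant', 'content': 'GLM-4 is an open chat model series from THUDM.'},
#  {'role': 'user', 'content': 'Summarize that in five words.'}]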