Kaaaaaaa committed
Commit 4237375
Parent: 5da23da

Update app.py

Files changed (1):
  app.py +170 -12
app.py CHANGED
@@ -1,16 +1,174 @@
+from typing import List, Union
+from pathlib import Path
  import gradio as gr
-from transformers import pipeline
+import torch
+import argparse
+from threading import Thread
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    TextIteratorStreamer,
+)
+import warnings
+import spaces
+import os
 
-# Load the model
-pipe = pipeline("text-generation", model="IndexTeam/Index-1.9B-Character", trust_remote_code=True)
+warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
 
-# Define the Gradio interface
-def generate_text(prompt):
-    messages = [
-        {"role": "user", "content": prompt},
-    ]
-    result = pipe(messages)
-    return result[0]['generated_text']
+MODEL_PATH = os.environ.get('MODEL_PATH', 'IndexTeam/Index-1.9B-Character')
+TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
 
-iface = gr.Interface(fn=generate_text, inputs="text", outputs="text")
-iface.launch()
+tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(MODEL_PATH,
+                                             torch_dtype=torch.bfloat16,
+                                             device_map="auto",
+                                             trust_remote_code=True)
+
+def _resolve_path(path: Union[str, Path]) -> Path:
+    return Path(path).expanduser().resolve()
+
+@spaces.GPU
+def hf_gen(dialog: List, top_k, top_p, temperature, repetition_penalty, max_dec_len):
+    """
+    Generate model output with the Huggingface API
+    Args:
+        dialog (List): List of dialog messages.
+        top_k (int): The number of highest probability vocabulary tokens to keep for top-k-filtering.
+        top_p (float): Only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+        temperature (float): Strictly positive float value used to modulate the logits distribution.
+        repetition_penalty (float): The parameter for repetition penalty.
+        max_dec_len (int): The maximum number of tokens to generate.
+    Yields:
+        str: Real-time generation results of the HF model.
+    """
+    inputs = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
+    enc = tokenizer(inputs, return_tensors="pt").to("cuda")
+    streamer = TextIteratorStreamer(tokenizer, **tokenizer.init_kwargs)
+    generation_kwargs = dict(
+        enc,
+        do_sample=True,
+        top_k=int(top_k),
+        top_p=float(top_p),
+        temperature=float(temperature),
+        repetition_penalty=float(repetition_penalty),
+        max_new_tokens=int(max_dec_len),
+        pad_token_id=tokenizer.eos_token_id,
+        streamer=streamer,
+    )
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+    answer = ""
+    for new_text in streamer:
+        answer += new_text
+        yield answer[len(inputs):]
+
+@spaces.GPU
+def generate(chat_history: List, query, top_k, top_p, temperature, repetition_penalty, max_dec_len, system_message):
+    """
+    Generate a response after hitting the "Submit" button
+    Args:
+        chat_history (List): List that stores all QA records.
+        query (str): Query of the current round.
+        top_p (float): Only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+        temperature (float): Strictly positive float value used to modulate the logits distribution.
+        max_dec_len (int): The maximum number of tokens to generate.
+    Yields:
+        List: Updated chat_history with the current round's QA.
+    """
+    assert query != "", "Input must not be empty!!!"
+    # apply chat template
+    model_input = []
+    if system_message:
+        model_input.append({
+            "role": "system",
+            "content": system_message
+        })
+    for q, a in chat_history:
+        model_input.append({"role": "user", "content": q})
+        model_input.append({"role": "assistant", "content": a})
+    model_input.append({"role": "user", "content": query})
+    # yield model generation
+    chat_history.append([query, ""])
+    for answer in hf_gen(model_input, top_k, top_p, temperature, repetition_penalty, max_dec_len):
+        chat_history[-1][1] = answer.strip(tokenizer.eos_token)
+        yield gr.update(value=""), chat_history
+
+@spaces.GPU
+def regenerate(chat_history: List, top_k, top_p, temperature, repetition_penalty, max_dec_len, system_message):
+    """
+    Re-generate the answer to the last round's query
+    Args:
+        chat_history (List): List that stores all QA records.
+        top_p (float): Only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+        temperature (float): Strictly positive float value used to modulate the logits distribution.
+        max_dec_len (int): The maximum number of tokens to generate.
+    Yields:
+        List: Updated chat_history.
+    """
+    assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
+    # apply chat template
+    model_input = []
+    if system_message:
+        model_input.append({
+            "role": "system",
+            "content": system_message
+        })
+    for q, a in chat_history[:-1]:
+        model_input.append({"role": "user", "content": q})
+        model_input.append({"role": "assistant", "content": a})
+    model_input.append({"role": "user", "content": chat_history[-1][0]})
+    # yield model generation
+    for answer in hf_gen(model_input, top_k, top_p, temperature, repetition_penalty, max_dec_len):
+        chat_history[-1][1] = answer.strip(tokenizer.eos_token)
+        yield gr.update(value=""), chat_history
+
+def clear_history():
+    """
+    Clear all chat history
+    Returns:
+        List: Empty chat history
+    """
+    torch.cuda.empty_cache()
+    return []
+
+def reverse_last_round(chat_history):
+    """
+    Remove the last QA round and keep the chat history before it
+    Args:
+        chat_history (List): List that stores all QA records.
+    Returns:
+        List: Updated chat_history without the last round.
+    """
+    assert len(chat_history) >= 1, "History is empty. Nothing to reverse!!"
+    return chat_history[:-1]
+
+# launch gradio demo
+with gr.Blocks(theme="soft") as demo:
+    gr.Markdown("""# Index-1.9B-Character Gradio Demo""")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            top_k = gr.Slider(1, 10, value=5, step=1, label="top_k")
+            top_p = gr.Slider(0, 1, value=0.8, step=0.1, label="top_p")
+            temperature = gr.Slider(0.1, 2.0, value=0.3, step=0.1, label="temperature")
+            repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, step=0.1, label="repetition_penalty")
+            max_dec_len = gr.Slider(1, 4096, value=1024, step=1, label="max_dec_len")
+            with gr.Row():
+                system_message = gr.Textbox(label="System Message", placeholder="Input your system message", value="你是由哔哩哔哩自主研发的大语言模型,名为“Index-1.9B-Character”。你能够根据用户传入的信息,帮助用户完成指定的任务,并生成恰当的、符合要求的回复。")
+        with gr.Column(scale=10):
+            chatbot = gr.Chatbot(bubble_full_width=False, height=500, label='Index-1.9B-Character')
+            user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8)
+            with gr.Row():
+                submit = gr.Button("🚀 Submit")
+                clear = gr.Button("🧹 Clear")
+                regen = gr.Button("🔄 Regenerate")
+                reverse = gr.Button("⬅️ Reverse")
+
+    submit.click(generate, inputs=[chatbot, user_input, top_k, top_p, temperature, repetition_penalty, max_dec_len, system_message],
+                 outputs=[user_input, chatbot])
+    regen.click(regenerate, inputs=[chatbot, top_k, top_p, temperature, repetition_penalty, max_dec_len, system_message],
+                outputs=[user_input, chatbot])
+    clear.click(clear_history, inputs=[], outputs=[chatbot])
+    reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])
+
+demo.queue().launch()
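
The core pattern this commit introduces — running model.generate on a background thread and iterating a TextIteratorStreamer to stream partial text back into the Gradio callback — can be tried in isolation. Below is a minimal sketch of that technique, assuming only the transformers and torch packages; the "gpt2" checkpoint and the example prompt are illustrative placeholders, not part of this commit.

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")          # placeholder small model
lm = AutoModelForCausalLM.from_pretrained("gpt2")

enc = tok("The quick brown fox", return_tensors="pt")
# skip_prompt=True yields only newly generated text, so no manual
# prompt-stripping (the app's answer[len(inputs):]) is needed here.
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until decoding finishes, so it runs on a worker thread
# while the caller consumes decoded chunks from the streamer as they arrive.
thread = Thread(target=lm.generate,
                kwargs=dict(**enc, max_new_tokens=32,
                            pad_token_id=tok.eos_token_id, streamer=streamer))
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)   # in the app, this is where the generator yields
thread.join()

In app.py the same loop accumulates chunks into answer and yields the growing string, which is what lets the Chatbot component update token by token instead of waiting for the full completion.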