Vic-729 committed
Commit b1d4399 · 1 Parent(s): fb2c9a5
Files changed (2)
  1. app.py +291 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,291 @@
+ """This script refers to the dialogue example of streamlit, the interactive
+ generation code of chatglm2 and transformers.
+
+ We mainly modified part of the code logic to adapt to the
+ generation of our model.
+ Please refer to these links below for more information:
+     1. streamlit chat example:
+         https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
+     2. chatglm2:
+         https://github.com/THUDM/ChatGLM2-6B
+     3. transformers:
+         https://github.com/huggingface/transformers
+ Please run with the command `streamlit run path/to/web_demo.py
+     --server.address=0.0.0.0 --server.port 7860`.
+ Using `python path/to/web_demo.py` may cause unknown problems.
+ """
+ # isort: skip_file
+ import copy
+ import warnings
+ from dataclasses import asdict, dataclass
+ from typing import Callable, List, Optional
+ import streamlit as st
+ import torch
+ from torch import nn
+ from transformers.generation.utils import (LogitsProcessorList,
+                                            StoppingCriteriaList)
+ from transformers.utils import logging
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip
+
+ logger = logging.get_logger(__name__)
+ model_name_or_path = "Vic-729/merged"
+
+ @dataclass
+ class GenerationConfig:
+     # this config is used for chat to provide more diversity
+     max_length: int = 32768
+     top_p: float = 0.8
+     temperature: float = 0.8
+     do_sample: bool = True
+     repetition_penalty: float = 1.005
+
+
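For orientation, a minimal sketch (not part of the diff) of how this dataclass is consumed further down: prepare_generation_config() builds one from the sidebar sliders, and main() splats it into generate_interactive() via dataclasses.asdict, so each field becomes a keyword argument.

# Sketch only: GenerationConfig -> plain kwargs for generate_interactive().
from dataclasses import asdict

cfg = GenerationConfig(max_length=2048, top_p=0.9, temperature=0.7)
kwargs = asdict(cfg)
# kwargs == {'max_length': 2048, 'top_p': 0.9, 'temperature': 0.7,
#            'do_sample': True, 'repetition_penalty': 1.005}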
+ @torch.inference_mode()
+ def generate_interactive(
+     model,
+     tokenizer,
+     prompt,
+     generation_config: Optional[GenerationConfig] = None,
+     logits_processor: Optional[LogitsProcessorList] = None,
+     stopping_criteria: Optional[StoppingCriteriaList] = None,
+     prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
+                                                 List[int]]] = None,
+     additional_eos_token_id: Optional[int] = None,
+     **kwargs,
+ ):
+     inputs = tokenizer([prompt], padding=True, return_tensors='pt')
+     input_length = len(inputs['input_ids'][0])
+     for k, v in inputs.items():
+         inputs[k] = v.cuda()
+     input_ids = inputs['input_ids']
+     _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+     if generation_config is None:
+         generation_config = model.generation_config
+     generation_config = copy.deepcopy(generation_config)
+     model_kwargs = generation_config.update(**kwargs)
+     bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
+         generation_config.bos_token_id,
+         generation_config.eos_token_id,
+     )
+     if isinstance(eos_token_id, int):
+         eos_token_id = [eos_token_id]
+     if additional_eos_token_id is not None:
+         eos_token_id.append(additional_eos_token_id)
+     has_default_max_length = kwargs.get(
+         'max_length') is None and generation_config.max_length is not None
+     if has_default_max_length and generation_config.max_new_tokens is None:
+         warnings.warn(
+             f"Using 'max_length''s default \
+                 ({repr(generation_config.max_length)}) \
+                 to control the generation length. "
+             'This behaviour is deprecated and will be removed from the \
+                 config in v5 of Transformers -- we'
+             ' recommend using `max_new_tokens` to control the maximum \
+                 length of the generation.',
+             UserWarning,
+         )
+     elif generation_config.max_new_tokens is not None:
+         generation_config.max_length = generation_config.max_new_tokens + \
+             input_ids_seq_length
+         if not has_default_max_length:
+             logger.warn(  # pylint: disable=W4902
+                 f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
+                 f"and 'max_length'(={generation_config.max_length}) seem to "
+                 "have been set. 'max_new_tokens' will take precedence. "
+                 'Please refer to the documentation for more information. '
+                 '(https://huggingface.co/docs/transformers/main/'
+                 'en/main_classes/text_generation)',
+                 UserWarning,
+             )
+
+     if input_ids_seq_length >= generation_config.max_length:
+         input_ids_string = 'input_ids'
+         logger.warning(
+             f'Input length of {input_ids_string} is {input_ids_seq_length}, '
+             f"but 'max_length' is set to {generation_config.max_length}. "
+             'This can lead to unexpected behavior. You should consider'
+             " increasing 'max_new_tokens'.")
+
+     # 2. Set generation parameters if not already defined
+     logits_processor = logits_processor if logits_processor is not None \
+         else LogitsProcessorList()
+     stopping_criteria = stopping_criteria if stopping_criteria is not None \
+         else StoppingCriteriaList()
+
+     logits_processor = model._get_logits_processor(
+         generation_config=generation_config,
+         input_ids_seq_length=input_ids_seq_length,
+         encoder_input_ids=input_ids,
+         prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+         logits_processor=logits_processor,
+     )
+
+     stopping_criteria = model._get_stopping_criteria(
+         generation_config=generation_config,
+         stopping_criteria=stopping_criteria)
+     logits_warper = model._get_logits_warper(generation_config)
+
+     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+     scores = None
+     while True:
+         model_inputs = model.prepare_inputs_for_generation(
+             input_ids, **model_kwargs)
+         # forward pass to get next token
+         outputs = model(
+             **model_inputs,
+             return_dict=True,
+             output_attentions=False,
+             output_hidden_states=False,
+         )
+
+         next_token_logits = outputs.logits[:, -1, :]
+
+         # pre-process distribution
+         next_token_scores = logits_processor(input_ids, next_token_logits)
+         next_token_scores = logits_warper(input_ids, next_token_scores)
+
+         # sample
+         probs = nn.functional.softmax(next_token_scores, dim=-1)
+         if generation_config.do_sample:
+             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+         else:
+             next_tokens = torch.argmax(probs, dim=-1)
+
+         # update generated ids, model inputs, and length for next step
+         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+         model_kwargs = model._update_model_kwargs_for_generation(
+             outputs, model_kwargs, is_encoder_decoder=False)
+         unfinished_sequences = unfinished_sequences.mul(
+             (min(next_tokens != i for i in eos_token_id)).long())
+
+         output_token_ids = input_ids[0].cpu().tolist()
+         output_token_ids = output_token_ids[input_length:]
+         for each_eos_token_id in eos_token_id:
+             if output_token_ids[-1] == each_eos_token_id:
+                 output_token_ids = output_token_ids[:-1]
+         response = tokenizer.decode(output_token_ids)
+
+         yield response
+         # stop when each sentence is finished
+         # or if we exceed the maximum length
+         if unfinished_sequences.max() == 0 or stopping_criteria(
+                 input_ids, scores):
+             break
+
+
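Outside of Streamlit, the generator above can be driven directly. A minimal sketch, assuming the model at Vic-729/merged uses the same ChatML-style template the app builds below, and that 92542 is the extra end-of-turn token id main() passes; load_model() is defined further down in this file:

# Sketch only: stream a completion from generate_interactive() in a plain script.
model, tokenizer = load_model()
prompt = ('<s><|im_start|>system\nYou are a helpful assistant.<|im_end|>\n'
          '<|im_start|>user\nHello!<|im_end|>\n'
          '<|im_start|>assistant\n')
for partial in generate_interactive(model=model,
                                    tokenizer=tokenizer,
                                    prompt=prompt,
                                    additional_eos_token_id=92542,
                                    **asdict(GenerationConfig(max_length=1024))):
    pass  # each `partial` is the full response decoded so far
print(partial)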
+ def on_btn_click():
+     del st.session_state.messages
+
+
+ @st.cache_resource
+ def load_model():
+     model = (AutoModelForCausalLM.from_pretrained(
+         model_name_or_path,
+         trust_remote_code=True).to(torch.bfloat16).cuda())
+     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
+                                               trust_remote_code=True)
+     return model, tokenizer
+
+
+ def prepare_generation_config():
+     with st.sidebar:
+         max_length = st.slider('Max Length',
+                                min_value=8,
+                                max_value=32768,
+                                value=32768)
+         top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
+         temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
+         st.button('Clear Chat History', on_click=on_btn_click)
+
+     generation_config = GenerationConfig(max_length=max_length,
+                                          top_p=top_p,
+                                          temperature=temperature)
+
+     return generation_config
+
+
+ user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
+ robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
+ cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\
+     <|im_start|>assistant\n'
+
+
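For concreteness, a quick illustration (not part of the diff) of what these templates render to; combine_history() below concatenates such turns, prefixed by the system block, into the prompt fed to the model:

# Sketch only: rendering one user turn and one assistant turn.
print(user_prompt.format(user='Hello!'), end='')
# <|im_start|>user
# Hello!<|im_end|>
print(robot_prompt.format(robot='Hi! How can I help?'), end='')
# <|im_start|>assistant
# Hi! How can I help?<|im_end|>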
+ def combine_history(prompt):
+     messages = st.session_state.messages
+     meta_instruction = ('You are a helpful, honest, '
+                         'and harmless AI assistant.')
+     total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
+     for message in messages:
+         cur_content = message['content']
+         if message['role'] == 'user':
+             cur_prompt = user_prompt.format(user=cur_content)
+         elif message['role'] == 'robot':
+             cur_prompt = robot_prompt.format(robot=cur_content)
+         else:
+             raise RuntimeError
+         total_prompt += cur_prompt
+     total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
+     return total_prompt
+
+
+ def main():
+     st.title('internlm2_5-7b-chat-assistant')
+
+     # torch.cuda.empty_cache()
+     print('load model begin.')
+     model, tokenizer = load_model()
+     print('load model end.')
+
+     generation_config = prepare_generation_config()
+
+     # Initialize chat history
+     if 'messages' not in st.session_state:
+         st.session_state.messages = []
+
+     # Display chat messages from history on app rerun
+     for message in st.session_state.messages:
+         with st.chat_message(message['role'], avatar=message.get('avatar')):
+             st.markdown(message['content'])
+
+     # Accept user input
+     if prompt := st.chat_input('What is up?'):
+         # Display user message in chat message container
+         with st.chat_message('user', avatar='user'):
+             st.markdown(prompt)
+         real_prompt = combine_history(prompt)
+         # Add user message to chat history
+         st.session_state.messages.append({
+             'role': 'user',
+             'content': prompt,
+             'avatar': 'user'
+         })
+
+         with st.chat_message('robot', avatar='assistant'):
+             message_placeholder = st.empty()
+             for cur_response in generate_interactive(
+                     model=model,
+                     tokenizer=tokenizer,
+                     prompt=real_prompt,
+                     additional_eos_token_id=92542,
+                     device='cuda:0',
+                     **asdict(generation_config),
+             ):
+                 # Display robot response in chat message container
+                 message_placeholder.markdown(cur_response + '▌')
+             message_placeholder.markdown(cur_response)
+         # Add robot response to chat history
+         st.session_state.messages.append({
+             'role': 'robot',
+             'content': cur_response,  # pylint: disable=undefined-loop-variable
+             'avatar': 'assistant',
+         })
+         torch.cuda.empty_cache()
+
+
+ if __name__ == '__main__':
+     main()
+
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ einops==0.7.0
+ protobuf==5.26.1
+ llama-index==0.11.20
+ llama-index-llms-replicate==0.3.0
+ llama-index-llms-openai-like==0.2.0
+ llama-index-embeddings-huggingface==0.3.1
+ llama-index-embeddings-instructor==0.2.1
+ torch==2.5.0
+ torchvision==0.20.0
+ torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu121