aka7774 committed
Commit 58190b4
1 parent: 81ccdcb

Upload 5 files

Files changed (4):
  1. app.py +104 -119
  2. fn.py +206 -94
  3. main.py +15 -4
  4. requirements.txt +1 -1
app.py CHANGED
@@ -1,133 +1,118 @@
 import fn
 import gradio as gr
-import models
-
-def fn_chat(instruction, input, model, dtype, is_messages, template, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
-    args = {
-        'instruction': instruction,
-        'input': input,
-        'model': model,
-        'dtype': dtype,
-        'is_messages': is_messages,
-        'template': template,
-        'max_new_tokens': int(max_new_tokens),
-        'temperature': float(temperature),
-        'top_p': float(top_p),
-        'top_k': int(top_k),
-        'repetition_penalty': float(repetition_penalty),
-    }
-
-    content = fn.infer(args)
-    return content
 
 with gr.Blocks() as demo:
-    opt = models.get_head_options()
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            model = gr.Textbox(
-                value=opt['model'],
-                label='model',
-                show_label=True,
-                interactive=True,
-                show_copy_button=True,
-            )
-
-            dtype = gr.Dropdown(
-                value=opt['dtype'],
-                choices=['int4','int8','fp16', 'bf16'],
-                label='dtype',
-                show_label=True,
-                interactive=True,
-                allow_custom_value=True,
-            )
-            template = gr.Textbox(
-                value=opt['template'],
-                lines=3,
-                label='template',
-                show_label=True,
-                interactive=True,
-                show_copy_button=True,
-            )
-            is_messages = gr.Checkbox(
-                value=opt['is_messages'],
-                label='is_messages',
-                show_label=True,
-                interactive=True,
-            )
-
-        with gr.Column(scale=1):
-            max_new_tokens = gr.Textbox(
-                value=opt['max_new_tokens'],
-                label='max_new_tokens',
-                show_label=True,
-                interactive=True,
-                show_copy_button=True,
-            )
-            temperature = gr.Textbox(
-                value=opt['temperature'],
-                label='temperature',
-                show_label=True,
-                interactive=True,
-                show_copy_button=True,
-            )
-            top_p = gr.Textbox(
-                value=opt['top_p'],
-                label='top_p',
-                show_label=True,
-                interactive=True,
-                show_copy_button=True,
             )
-            top_k = gr.Textbox(
-                value=opt['top_k'],
-                label='top_k',
-                show_label=True,
-                interactive=True,
-                show_copy_button=True,
             )
-            repetition_penalty = gr.Textbox(
-                value=opt['repetition_penalty'],
-                label='repetition_penalty',
-                show_label=True,
-                interactive=True,
-                show_copy_button=True,
             )
 
-    with gr.Accordion('Preset', open=False):
-        gr.Examples(
-            models.get_examples(),
-            [model, dtype, is_messages, template, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-        )
 
-    with gr.Row():
-        with gr.Column(scale=1):
-            instruction = gr.Textbox(
-                lines=20,
-                label='instruction',
-                show_label=True,
-                interactive=True,
-                show_copy_button=True,
-            )
-            user_input = gr.Textbox(
-                lines=1,
-                label='input',
-                show_label=True,
-                interactive=True,
-                show_copy_button=True,
-            )
-            chat_button = gr.Button(value='chat')
 
-        with gr.Column(scale=1):
-            said = gr.Textbox(
-                label='said',
-                lines=15,
-                show_label=True,
-                show_copy_button=True,
-            )
 
-    chat_button.click(
-        fn=fn_chat,
-        inputs=[instruction, user_input, model, dtype, is_messages, template, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
         outputs=[said],
     )
 

 import fn
 import gradio as gr
 
 with gr.Blocks() as demo:
+    with gr.Tab('config'):
+        info = gr.Markdown()
+        with gr.Row():
+            with gr.Column(scale=1):
+                model = gr.Textbox(
+                    value=fn.cfg['model_name'],
+                    label='model',
+                    interactive=True,
+                    show_copy_button=True,
                 )
+                qtype = gr.Dropdown(
+                    value=fn.cfg['qtype'],
+                    choices=['bnb','gptq','gguf', 'awq'],
+                    label='qtype',
+                    interactive=True,
                 )
+                dtype = gr.Dropdown(
+                    value=fn.cfg['dtype'],
+                    choices=['4bit','8bit','fp16', 'bf16'],
+                    label='dtype',
+                    interactive=True,
+                    allow_custom_value=True,
                 )
 
+            with gr.Column(scale=1):
+                max_new_tokens = gr.Textbox(
+                    value=fn.cfg['max_new_tokens'],
+                    label='max_new_tokens',
+                    interactive=True,
+                    show_copy_button=True,
+                )
+                temperature = gr.Textbox(
+                    value=fn.cfg['temperature'],
+                    label='temperature',
+                    interactive=True,
+                    show_copy_button=True,
+                )
+                top_p = gr.Textbox(
+                    value=fn.cfg['top_p'],
+                    label='top_p',
+                    interactive=True,
+                    show_copy_button=True,
+                )
+                top_k = gr.Textbox(
+                    value=fn.cfg['top_k'],
+                    label='top_k',
+                    interactive=True,
+                    show_copy_button=True,
+                )
+                repetition_penalty = gr.Textbox(
+                    value=fn.cfg['repetition_penalty'],
+                    label='repetition_penalty',
+                    interactive=True,
+                    show_copy_button=True,
+                )
 
+        with gr.Row():
+            with gr.Column(scale=1):
+                inst_template = gr.Textbox(
+                    value='',
+                    lines=10,
+                    label='inst_template',
+                    interactive=True,
+                    show_copy_button=True,
+                )
+            with gr.Column(scale=1):
+                chat_template = gr.Textbox(
+                    value='',
+                    lines=10,
+                    label='chat_template',
+                    interactive=True,
+                    show_copy_button=True,
+                )
 
+        set_button = gr.Button(value='Save')
+
+    with gr.Tab('inctruct'):
+        with gr.Row():
+            with gr.Column(scale=1):
+                instruction = gr.Textbox(
+                    lines=20,
+                    label='instruction',
+                    interactive=True,
+                    show_copy_button=True,
+                )
+                input = gr.Textbox(
+                    lines=1,
+                    label='input',
+                    interactive=True,
+                    show_copy_button=True,
+                )
+            with gr.Column(scale=1):
+                said = gr.Textbox(
+                    label='said',
+                    lines=25,
+                    show_copy_button=True,
+                )
+        inst_button = gr.Button(value='inst')
+
+    with gr.Tab('chat'):
+        gr.ChatInterface(fn.chat)
+
+    set_button.click(
+        fn=fn.set_config,
+        inputs=[model, qtype, dtype, instruction, inst_template, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+        outputs=[info],
+    )
 
+    inst_button.click(
+        fn=fn.chat,
+        inputs=[input, input, instruction],
         outputs=[said],
     )
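For quick local testing, the revised Blocks UI can also be launched on its own; a minimal sketch, not part of the commit (in this repo the app is normally mounted under FastAPI by main.py via gr.mount_gradio_app):

    # standalone launcher -- hypothetical helper, not shipped in this commit
    from app import demo

    if __name__ == '__main__':
        demo.queue().launch()  # queue() lets the streaming chat tab deliver partial output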
 
fn.py CHANGED
@@ -5,20 +5,32 @@ import datetime
 import json
 import csv
 import gc
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
-from transformers import TextStreamer, TextIteratorStreamer
-from transformers import GenerationConfig, AutoConfig, GPTQConfig, AwqConfig
-from models import models
 
 tokenizer = None
 model = None
-loaded_model_name = None
-loaded_dtype = None
 
-def load_model(model_name, dtype = 'int4'):
-    global tokenizer, model, loaded_model_name, loaded_dtype
 
-    if loaded_model_name == model_name and loaded_dtype == dtype:
         return
 
     del model

 import json
 import csv
 import gc
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import TextIteratorStreamer
+from transformers import BitsAndBytesConfig, GPTQConfig
+from threading import Thread
 
 tokenizer = None
 model = None
+default_cfg = {
+    'model_name': None,
+    'qtype': 'bnb',
+    'dtype': '4bit',
+    'instruction': None,
+    'inst_template': None,
+    'chat_template': None,
+    'max_new_tokens': 1024,
+    'temperature': 0.9,
+    'top_p': 0.95,
+    'top_k': 40,
+    'repetition_penalty': 1.2,
+}
+cfg = default_cfg.copy()
 
+def load_model(model_name, qtype = 'bnb', dtype = '4bit'):
+    global tokenizer, model, cfg
 
+    if cfg['model_name'] == model_name and cfg['qtype'] == qtype and cfg['dtype'] == dtype:
         return
 
     del model
@@ -29,100 +41,200 @@ def load_model(model_name, dtype = 'int4'):
     torch.cuda.empty_cache()
 
     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
-    if dtype == 'int4':
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            device_map="auto",
-            trust_remote_code=True,
-            quantization_config=BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_compute_dtype=torch.bfloat16,
-            ),
-        )
-    elif dtype == 'int8':
-        model = AutoModelForCausalLM.from_pretrained(
             model_name,
             device_map="auto",
             trust_remote_code=True,
-            quantization_config=BitsAndBytesConfig(
-                torch_dtype=torch.bfloat16,
-                load_in_8bit=True,
-            ),
-        )
-    elif dtype == 'fp16':
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            device_map="auto",
-            trust_remote_code=True,
-            torch_dtype=torch.float16,
-        )
-    elif dtype == 'bf16':
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            device_map="auto",
-            trust_remote_code=True,
-            torch_dtype=torch.bfloat16,
-        )
-    else:
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            trust_remote_code=True,
-            device_map="auto",
         )
 
-    loaded_model_name = model_name
-    loaded_dtype = dtype
 
-def infer(args: dict):
-    global tokenizer, model, loaded_model_name
 
-    if 'model' in args:
-        args['model_name'] = args['model']
 
-    if not tokenizer or 'model_name' in args and loaded_model_name != args['model_name']:
-        if 'dtype' in args:
-            load_model(args['model_name'], args['dtype'])
-        else:
-            load_model(args['model_name'])
-
-    config = {}
-    if args['model_name'] in models:
-        config = models[args['model_name']]
-    config.update(args)
-
-    if config['is_messages']:
-        messages = []
-        messages.append({"role": "system", "content": args['instruction']})
-        if args['input']:
-            messages.append({"role": "user", "content": args['input']})
-        tprompt = tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
     else:
-        tprompt = config['template'].format(bos_token=tokenizer.bos_token, instruction=args['instruction'], input=args['input'])
-
-    kwargs = config.copy()
-    for k in ['model_name', 'template', 'instruction', 'input', 'location', 'endpoint', 'model', 'dtype', 'is_messages']:
-        if k in kwargs:
-            del kwargs[k]
-
-    with torch.no_grad():
-        token_ids = tokenizer.encode(tprompt, add_special_tokens=False, return_tensors="pt")
-        if config['is_messages']:
-            output_ids = model.generate(
-                input_ids=token_ids.to(model.device),
-                do_sample=True,
-                **kwargs,
-            )
         else:
-            output_ids = model.generate(
-                input_ids=token_ids.to(model.device),
-                do_sample=True,
-                pad_token_id=tokenizer.pad_token_id,
-                bos_token_id=tokenizer.bos_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-                **kwargs,
-            )
-        out = output_ids.tolist()[0][token_ids.size(1) :]
-        content = tokenizer.decode(out, skip_special_tokens=True)
 
-    return content

     torch.cuda.empty_cache()
 
     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+    match qtype:
+        case 'bnb':
+            match dtype:
+                case '4bit' | 'int4':
+                    kwargs = dict(
+                        quantization_config=BitsAndBytesConfig(
+                            load_in_4bit=True,
+                            bnb_4bit_compute_dtype=torch.bfloat16,
+                        ),
+                    )
+                case '8bit' | 'int8':
+                    kwargs = dict(
+                        quantization_config=BitsAndBytesConfig(
+                            load_in_8bit=True,
+                            bnb_4bit_compute_dtype=torch.bfloat16,
+                        ),
+                    )
+                case 'fp16':
+                    kwargs = dict(
+                        torch_dtype=torch.float16,
+                    )
+                case 'bf16':
+                    kwargs = dict(
+                        torch_dtype=torch.bfloat16,
+                    )
+                case _:
+                    kwargs = dict()
+        case 'gptq':
+            match dtype:
+                case '4bit' | 'int4':
+                    kwargs = dict(
+                        quantization_config=GPTQConfig(
+                            bits=4,
+                            tokenizer=tokenizer,
+                        ),
+                    )
+                case '8bit' | 'int8':
+                    kwargs = dict(
+                        quantization_config=GPTQConfig(
+                            bits=8,
+                            tokenizer=tokenizer,
+                        ),
+                    )
+        case 'gguf':
+            kwargs = dict(
+                gguf_file=qtype,
+            )
+        case 'awq':
+            match dtype:
+                case 'fa2':
+                    kwargs = dict(
+                        use_flash_attention_2=True,
+                    )
+                case _:
+                    kwargs = dict()
 
+    model = AutoModelForCausalLM.from_pretrained(
         model_name,
         device_map="auto",
         trust_remote_code=True,
+        **kwargs,
     )
 
+    cfg['model_name'] = model_name
+    cfg['qtype'] = qtype
+    cfg['dtype'] = dtype
 
+def clear_config():
+    global cfg
+    cfg = default_cfg.copy()
+
+def set_config(model_name, qtype, dtype, instruction, inst_template, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
+    global cfg
+    load_model(model_name, qtype, dtype)
+    cfg.update({
+        'instruction': instruction,
+        'inst_template': inst_template,
+        'chat_template': chat_template,
+        'max_new_tokens': int(max_new_tokens),
+        'temperature': float(temperature),
+        'top_p': float(top_p),
+        'top_k': int(top_k),
+        'repetition_penalty': float(repetition_penalty),
+    })
+    return 'done.'
+
+def set_config_args(args):
+    global cfg
+
+    load_model(args['model_name'], args['qtype'], args['dtype'])
+    cfg.update(args)
+
+    return 'done.'
+
+def chatinterface_to_messages(message, history):
+    global cfg
+
+    messages = []
 
+    if cfg['instruction']:
+        messages.append({'role': 'system', 'content': cfg['instruction']})
 
+    for pair in history:
+        [user, assistant] = pair
+        if user:
+            messages.append({'role': 'user', 'content': user})
+        if assistant:
+            messages.append({'role': 'assistant', 'content': assistant})
+
+    if message:
+        messages.append({'role': 'user', 'content': message})
+
+    return messages
+
+def chat(message, history = [], instruction = None, args = {}):
+    global tokenizer, model, cfg
+
+    if instruction:
+        cfg['instruction'] = instruction
+        prompt = apply_template(message)
     else:
+        messages = chatinterface_to_messages(message, history)
+        prompt = apply_template(messages)
+
+    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
+    )
+
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        do_sample=True,
+        num_beams=1,
+    )
+    for k in [
+        'max_new_tokens',
+        'temperature',
+        'top_p',
+        'top_k',
+        'repetition_penalty'
+    ]:
+        if cfg[k]:
+            generate_kwargs[k] = cfg[k]
+
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    model_output = ""
+    for new_text in streamer:
+        model_output += new_text
+        if 'fastapi' in args:
+            # fastapi wants only the newly generated delta
+            yield new_text
         else:
+            # gradio wants the full text generated so far on every yield
+            yield model_output
+
+    return model_output
+
+def infer(args: dict):
+    global cfg
+
+    if 'model_name' in args:
+        load_model(args['model_name'], args['qtype'], args['dtype'])
+
+    for k in [
+        'instruction',
+        'inst_template',
+        'chat_template',
+        'max_new_tokens',
+        'temperature',
+        'top_p',
+        'top_k',
+        'repetition_penalty'
+    ]:
+        cfg[k] = args[k]
+
+    if 'messages' in args:
+        return chat(args['input'], args['messages'])
+    if 'instruction' in args:
+        return instruct(args['instruction'], args['input'])
+
+def apply_template(messages):
+    global tokenizer, cfg
+
+    if cfg['chat_template']:
+        tokenizer.chat_template = cfg['chat_template']
 
+    if type(messages) is str:
+        if cfg['inst_template']:
+            return cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
+        return cfg['instruction']
+    if type(messages) is list:
+        return tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
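Taken together, the reworked fn module is configured once (set_config / set_config_args loads the model and stores the generation settings in cfg) and then queried through the streaming chat() generator. A minimal usage sketch, assuming a CUDA-capable environment; the model name is a placeholder, not something this commit ships:

    import fn

    # Load a model and remember model_name/qtype/dtype in fn.cfg.
    fn.set_config_args({
        'model_name': 'your-org/your-instruct-model',  # placeholder
        'qtype': 'bnb',
        'dtype': '4bit',
    })
    fn.cfg['instruction'] = 'You are a helpful assistant.'

    # Gradio-style consumption: each yield is the full text generated so far.
    for partial in fn.chat('Hello!'):
        print(partial)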
main.py CHANGED
@@ -9,8 +9,7 @@ from fastapi.staticfiles import StaticFiles
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field
 from fastapi.exceptions import RequestValidationError
-from fastapi.responses import JSONResponse
-
 import fn
 import gradio as gr
 from app import demo

 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field
 from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse, StreamingResponse
 import fn
 import gradio as gr
 from app import demo
@@ -27,7 +26,19 @@ app.add_middleware(
 
 gr.mount_gradio_app(app, demo, path="/gradio")
 
 @app.post("/infer")
 async def api_infer(args: dict):
-    content = fn.infer(args)
-    return {'content': content}

 
 gr.mount_gradio_app(app, demo, path="/gradio")
 
+@app.post("/set_config")
+async def api_set_config(args: dict):
+    content = fn.set_config_args(args)
+    return {'content': content}
+
 @app.post("/infer")
 async def api_infer(args: dict):
+    args['fastapi'] = True
+    if 'stream' in args and args['stream']:
+        return StreamingResponse(
+            fn.chat(args['input'], [], args['instruct'], args),
+            media_type="text/event-stream",
+        )
+    else:
+        content = fn.chat(args['input'], [], args['instruct'], args)
+        return {'content': content}
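Once the server is running, the two endpoints can be exercised with any HTTP client. A hedged client sketch using requests; the base URL and model name are assumptions, not part of the commit:

    import requests

    base = 'http://localhost:8000'  # adjust to wherever the FastAPI app is served

    # Load a model and store generation settings via fn.set_config_args().
    requests.post(f'{base}/set_config', json={
        'model_name': 'your-org/your-instruct-model',  # placeholder
        'qtype': 'bnb',
        'dtype': '4bit',
    })

    # Stream tokens from /infer; the handler wraps fn.chat() in a StreamingResponse.
    with requests.post(f'{base}/infer', json={
        'input': 'Hello',
        'instruct': 'You are a helpful assistant.',
        'stream': True,
    }, stream=True) as r:
        for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end='', flush=True)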
requirements.txt CHANGED
@@ -4,9 +4,9 @@ transformers
 accelerate
 sentencepiece
 bitsandbytes
 scipy
 tiktoken
 einops
-transformers_stream_generator
 protobuf
 python-multipart

 accelerate
 sentencepiece
 bitsandbytes
+autoawq
 scipy
 tiktoken
 einops
 protobuf
 python-multipart