aka7774 committed
Commit 875d37f
Parent: bc16f24

Upload 2 files

Files changed (2)
  1. fn.py +61 -32
  2. main.py +10 -10
fn.py CHANGED
@@ -86,7 +86,8 @@ def chatinterface_to_messages(message, history):
     messages = []
 
     if cfg['instruction']:
-        messages.append({'role': 'system', 'content': cfg['instruction']})
+        messages.append({'role': 'user', 'content': cfg['instruction']})
+        messages.append({'role': 'assistant', 'content': 'I understand.'})
 
     for pair in history:
         [user, assistant] = pair
@@ -100,32 +101,43 @@ def chatinterface_to_messages(message, history):
 
     return messages
 
-def apply_template(messages):
+def apply_template(message, history, args):
     global tokenizer, cfg
 
+    if 'input' in args:
+        message = args['input']
+    if 'instruction' in args:
+        cfg['instruction'] = args['instruction']
+
+    if 'messages' in args:
+        messages = args['messages']
+    elif history:
+        messages = chatinterface_to_messages(message, history)
+    else:
+        messages = {}
+
     if cfg['chat_template']:
         tokenizer.chat_template = cfg['chat_template']
 
-    if type(messages) is str:
+    if message:
         if cfg['inst_template']:
-            return cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
-        messages = [
-            {'role': 'user', 'content': cfg['instruction']},
-            {'role': 'assistant', 'content': 'I understand.'},
-            {'role': 'user', 'content': messages},
-        ]
-    if type(messages) is list:
-        return tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
-
-def chat(message, history = [], instruction = None, args = {}):
+            return cfg['inst_template'].format(instruction=cfg['instruction'], input=message)
+        if cfg['instruction']:
+            messages = [
+                {'role': 'user', 'content': cfg['instruction']},
+                {'role': 'assistant', 'content': 'I understand.'},
+                {'role': 'user', 'content': message},
+            ]
+        else:
+            messages = [
+                {'role': 'user', 'content': message},
+            ]
+    return tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
+
+def chat(message = None, history = [], args = {}):
     global tokenizer, model, cfg
 
-    if instruction:
-        cfg['instruction'] = instruction
-        prompt = apply_template(message)
-    else:
-        messages = chatinterface_to_messages(message, history)
-        prompt = apply_template(messages)
+    prompt = apply_template(message, history, args)
 
     inputs = tokenizer(prompt, return_tensors="pt",
         padding=True, max_length=cfg['max_length'], truncation=True).to("cuda")
@@ -164,23 +176,40 @@ def chat(message, history = [], instruction = None, args = {}):
     # gradio always wants the full text returned
     yield model_output
 
-def infer(message, history = [], instruction = None, args = {}):
-    content = ''
-    for s in chat(message, history, instruction, args):
-        content += s
-    return content
+def infer(message = None, history = [], args = {}):
+    global tokenizer, model, cfg
+
+    prompt = apply_template(message, history, args)
+
+    inputs = tokenizer(prompt, return_tensors="pt",
+        padding=True, max_length=cfg['max_length'], truncation=True).to("cuda")
 
-def numel(message, history = [], instruction = None, args = {}):
+    generate_kwargs = dict(
+        inputs,
+        do_sample=True,
+        num_beams=1,
+        use_cache=True,
+    )
+
+    for k in [
+        'max_new_tokens',
+        'temperature',
+        'top_p',
+        'top_k',
+        'repetition_penalty'
+    ]:
+        if cfg[k]:
+            generate_kwargs[k] = cfg[k]
+
+    output_ids = model.generate(**generate_kwargs)
+    return tokenizer.decode(output_ids.tolist()[0][inputs['input_ids'].size(1):], skip_special_tokens=True)
+
+def numel(message = None, history = [], args = {}):
     global tokenizer, model, cfg
 
-    if instruction:
-        cfg['instruction'] = instruction
-        prompt = apply_template(message)
-    else:
-        messages = chatinterface_to_messages(message, history)
-        prompt = apply_template(messages)
+    prompt = apply_template(message, history, args)
 
-    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
+    model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
     return torch.numel(model_inputs['input_ids'])
 
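With this commit every entry point in fn.py (chat, infer, numel) builds its prompt through apply_template(message, history, args), so a caller can drive everything from a single args dict: 'input' and 'instruction' override the positional message and cfg['instruction'], and 'messages' can supply a ready-made conversation. A minimal usage sketch, assuming model, tokenizer and cfg have already been initialised elsewhere (that setup step and the cfg contents are assumptions, not part of this diff):

    import fn  # assumes fn.model, fn.tokenizer and fn.cfg are already populated (setup step not shown in this commit)

    args = {
        'instruction': 'You are a helpful assistant.',  # stored into cfg['instruction'] by apply_template()
        'input': 'Explain what torch.numel() does.',    # overrides the positional message argument
    }

    print(fn.numel(args=args))   # token count of the rendered prompt
    print(fn.infer(args=args))   # non-streaming: decodes only the newly generated tokens

    for text in fn.chat(args=args):  # streaming: yields the cumulative output, as Gradio expects
        print(text)
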
 
main.py CHANGED
@@ -33,17 +33,17 @@ async def api_set_config(args: dict):
 
 @app.post("/infer")
 async def api_infer(args: dict):
-    args['fastapi'] = True
-    if 'stream' in args and args['stream']:
-        return StreamingResponse(
-            fn.chat(args['input'], [], args['instruct'], args),
-            media_type="text/event-stream",
-        )
-    else:
-        content = fn.infer(args['input'], [], args['instruct'], args)
-        return {'content': content}
+    content = fn.infer(args=args)
+    return {'content': content}
+
+@app.post("/stream")
+async def api_stream(args: dict):
+    return StreamingResponse(
+        fn.chat(args=args),
+        media_type="text/event-stream",
+    )
 
 @app.post("/numel")
 async def api_numel(args: dict):
-    content = fn.numel(args['input'], [], args['instruct'], args)
+    content = fn.numel(args=args)
     return {'numel': content}
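On the FastAPI side, /infer now always returns JSON, streaming moves to the new /stream endpoint, and all three routes simply forward the request body to fn as its args dict (the old 'instruct' key and the args['fastapi'] flag are gone; 'instruction' and 'input' go in the body instead). A rough client-side sketch, with the host and port as assumptions:

    import requests

    base = 'http://127.0.0.1:8000'  # host/port are assumptions
    payload = {'instruction': 'You are a helpful assistant.', 'input': 'Hello!'}

    # Non-streaming generation and prompt token count
    print(requests.post(f'{base}/infer', json=payload).json()['content'])
    print(requests.post(f'{base}/numel', json=payload).json()['numel'])

    # /stream returns the cumulative text as a text/event-stream response
    with requests.post(f'{base}/stream', json=payload, stream=True) as r:
        for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end='', flush=True)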