aka7774 commited on
Commit
a0226d6
1 Parent(s): 8dc6a10

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +13 -1
  2. fn.py +17 -0
  3. main.py +5 -0
app.py CHANGED
@@ -96,10 +96,16 @@ with gr.Blocks() as demo:
96
  with gr.Column(scale=1):
97
  said = gr.Textbox(
98
  label='said',
99
- lines=25,
 
 
 
 
 
100
  show_copy_button=True,
101
  )
102
  inst_button = gr.Button(value='inst')
 
103
 
104
  with gr.Tab('chat'):
105
  gr.ChatInterface(fn.chat)
@@ -116,5 +122,11 @@ with gr.Blocks() as demo:
116
  outputs=[said],
117
  )
118
 
 
 
 
 
 
 
119
  if __name__ == '__main__':
120
  demo.launch()
 
96
  with gr.Column(scale=1):
97
  said = gr.Textbox(
98
  label='said',
99
+ lines=20,
100
+ show_copy_button=True,
101
+ )
102
+ numel = gr.Textbox(
103
+ lines=1,
104
+ label='numel',
105
  show_copy_button=True,
106
  )
107
  inst_button = gr.Button(value='inst')
108
+ numel_button = gr.Button(value='numel')
109
 
110
  with gr.Tab('chat'):
111
  gr.ChatInterface(fn.chat)
 
122
  outputs=[said],
123
  )
124
 
125
+ numel_button.click(
126
+ fn=fn.numel,
127
+ inputs=[input, input, instruction],
128
+ outputs=[numel],
129
+ )
130
+
131
  if __name__ == '__main__':
132
  demo.launch()
fn.py CHANGED
@@ -180,6 +180,9 @@ def chat(message, history = [], instruction = None, args = {}):
180
  prompt = apply_template(messages)
181
 
182
  model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
 
 
 
183
 
184
  streamer = TextIteratorStreamer(
185
  tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
@@ -220,3 +223,17 @@ def infer(message, history = [], instruction = None, args = {}):
220
  for s in chat(message, history, instruction, args):
221
  content += s
222
  return content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  prompt = apply_template(messages)
181
 
182
  model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
183
+ # どうせ0固定だしエラーが出るので消してしまう
184
+ if 'token_type_ids' in model_inputs:
185
+ del model_inputs['token_type_ids']
186
 
187
  streamer = TextIteratorStreamer(
188
  tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
 
223
  for s in chat(message, history, instruction, args):
224
  content += s
225
  return content
226
+
227
+ def numel(message, history = [], instruction = None, args = {}):
228
+ global tokenizer, model, cfg
229
+
230
+ if instruction:
231
+ cfg['instruction'] = instruction
232
+ prompt = apply_template(message)
233
+ else:
234
+ messages = chatinterface_to_messages(message, history)
235
+ prompt = apply_template(messages)
236
+
237
+ model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
238
+
239
+ return torch.numel(model_inputs['input_ids'])
main.py CHANGED
@@ -42,3 +42,8 @@ async def api_infer(args: dict):
42
  else:
43
  content = fn.infer(args['input'], [], args['instruct'], args)
44
  return {'content': content}
 
 
 
 
 
 
42
  else:
43
  content = fn.infer(args['input'], [], args['instruct'], args)
44
  return {'content': content}
45
+
46
+ @app.post("/numel")
47
+ async def api_numel(args: dict):
48
+ content = fn.numel(args['input'], [], args['instruct'], args)
49
+ return {'numel': content}