Arturo Jiménez de los Galanes Reguillos commited on
Commit
4157280
1 Parent(s): 74704c7

Try to handle inputs slightly different

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -36,12 +36,13 @@ def translate(python, progress=gr.Progress()):
36
  messages_for(python),
37
  tokenize=False,
38
  add_generation_prompt=True,
39
- return_tensors="pt").to(model.device)
40
- inputs = tokenizer(formatted_prompt, return_tensors="pt", padding=True)
41
  attention_mask = inputs["attention_mask"]
 
42
 
43
  outputs = model.generate(
44
- inputs['input_ids'],
45
  attention_mask=attention_mask,
46
  max_new_tokens=1024,
47
  do_sample=False,
@@ -49,7 +50,7 @@ def translate(python, progress=gr.Progress()):
49
  eos_token_id=tokenizer.eos_token_id,
50
  )
51
  progress(1, desc="Finished")
52
- return tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
53
  '''
54
  generation_kwargs = dict(
55
  inputs,
 
36
  messages_for(python),
37
  tokenize=False,
38
  add_generation_prompt=True,
39
+ return_tensors="pt")
40
+ inputs = tokenizer(formatted_prompt, return_tensors="pt", padding=True).to(model.device)
41
  attention_mask = inputs["attention_mask"]
42
+ input_ids = inputs['input_ids']
43
 
44
  outputs = model.generate(
45
+ input_ids,
46
  attention_mask=attention_mask,
47
  max_new_tokens=1024,
48
  do_sample=False,
 
50
  eos_token_id=tokenizer.eos_token_id,
51
  )
52
  progress(1, desc="Finished")
53
+ return tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
54
  '''
55
  generation_kwargs = dict(
56
  inputs,