mtensor committed
Commit 3856d86
Parent: 8c4b5c5

only decode the end of the generation

Files changed (1): README.md +6 -6
README.md CHANGED
@@ -64,8 +64,8 @@ for k, v in model_inputs.items():
     model_inputs[k] = v.to("cuda:0")
 
 generation_output = model.generate(**model_inputs, max_new_tokens=7)
-generation_text = processor.batch_decode(generation_output, skip_special_tokens=True)[0][-35:]
-assert generation_text == "A bus parked on the side of a road."
+generation_text = processor.batch_decode(generation_output[:, -7:], skip_special_tokens=True)
+assert generation_text == ['A bus parked on the side of a road.']
 ```
 
 Fuyu can also perform some question answering on natural images and charts/diagrams (though fine-tuning may be required for good performance):
@@ -79,8 +79,8 @@ for k, v in model_inputs.items():
     model_inputs[k] = v.to("cuda:0")
 
 generation_output = model.generate(**model_inputs, max_new_tokens=6)
-generation_text = processor.batch_decode(generation_output, skip_special_tokens=True)[0][-17:]
-assert generation_text == "The bus is blue.\n"
+generation_text = processor.batch_decode(generation_output[:, -6:], skip_special_tokens=True)
+assert generation_text == ["The bus is blue.\n"]
 
 
 text_prompt = "What is the highest life expectancy at birth of male?\n"
@@ -92,8 +92,8 @@ for k, v in model_inputs.items():
     model_inputs[k] = v.to("cuda:0")
 
 generation_output = model.generate(**model_inputs, max_new_tokens=16)
-generation_text = processor.batch_decode(generation_output, skip_special_tokens=True)[0][-55:]
-assert generation_text == "The life expectancy at birth of males in 2018 is 80.7.\n"
+generation_text = processor.batch_decode(generation_output[:, -16:], skip_special_tokens=True)
+assert generation_text == ["The life expectancy at birth of males in 2018 is 80.7.\n"]
 ```
 
 ## Uses
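
The point of the change: the old assertions sliced a fixed number of *characters* off the fully decoded string (`[-35:]`, `[-17:]`, `[-55:]`), which silently depends on how the prompt detokenizes; the new version slices the last `max_new_tokens` *token ids* before decoding, so only the completion is ever decoded. Below is a minimal sketch of the pattern, assuming `model`, `processor`, and `model_inputs` are set up as in the surrounding README (FuyuForCausalLM / FuyuProcessor); the prompt-length variant at the end is an alternative idiom, not part of this commit.

```python
# Minimal sketch of the decoding pattern this commit adopts. Assumes
# `model`, `processor`, and `model_inputs` already exist as in the
# README's earlier setup code; no new API is introduced here.

max_new_tokens = 7
generation_output = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

# Old approach: decode prompt + completion, then slice characters.
# The character offset (-35 here) breaks as soon as the prompt or its
# detokenization changes.
full_text = processor.batch_decode(generation_output, skip_special_tokens=True)[0]
old_style = full_text[-35:]

# New approach: slice the token ids first. `generate` returns
# [prompt tokens | new tokens] along dim 1, so the last
# `max_new_tokens` positions are exactly the completion.
generation_text = processor.batch_decode(
    generation_output[:, -max_new_tokens:], skip_special_tokens=True
)
print(generation_text)  # e.g. ['A bus parked on the side of a road.']

# Alternative idiom (not part of this commit): slice from the prompt
# length instead, which stays correct even if generation stops early.
prompt_len = model_inputs["input_ids"].shape[1]
completion_only = generation_output[:, prompt_len:]
```

One caveat the commit relies on: slicing with `-max_new_tokens` is only exact when the model emits all `max_new_tokens` tokens (no early end-of-sequence stop), which holds for the pinned assertions here; the prompt-length slice above is the safer general form.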