xxyyy123 commited on
Commit
1418e34
·
verified ·
1 Parent(s): 68fded4

Update README.md

Browse files

Add batch inference example.

Files changed (1) hide show
  1. README.md +52 -0
README.md CHANGED
@@ -78,6 +78,58 @@ with torch.inference_mode():
78
  print(f'Output:\n{output}')
79
  ```
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  ## Citation
82
  If you find Ovis useful, please cite the paper
83
  ```
 
78
  print(f'Output:\n{output}')
79
  ```
80
 
81
+ <details>
82
+ <summary>Batch inference</summary>
83
+
84
+ ```python
85
+ batch_inputs = [
86
+ ('example_image1.jpeg', 'Describe the content of this image.'),
87
+ ('example_image2.jpeg', 'What is the equation in the image?')
88
+ ]
89
+
90
+ batch_input_ids = []
91
+ batch_attention_mask = []
92
+ batch_pixel_values = []
93
+
94
+ for image_path, text in batch_inputs:
95
+ image = Image.open(image_path)
96
+ query = f'<image>\n{text}'
97
+ prompt, input_ids, pixel_values = model.preprocess_inputs(query, [image])
98
+ attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
99
+ input_ids = input_ids.unsqueeze(0).to(device=model.device)
100
+ attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
101
+ pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]
102
+ batch_input_ids.append(input_ids.squeeze())
103
+ batch_attention_mask.append(attention_mask.squeeze())
104
+ batch_pixel_values.append(pixel_values)
105
+
106
+ pad_batch_input_ids = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[0]) for i in batch_input_ids],batch_first=True, padding_value=0.0).flip(dims=[1])
107
+ pad_batch_input_ids = pad_batch_input_ids[:,-model.config.multimodal_max_length:]
108
+ pad_batch_attention_mask = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[0]) for i in batch_attention_mask],batch_first=True, padding_value=False).flip(dims=[1])
109
+ pad_batch_attention_mask = pad_batch_attention_mask[:,-model.config.multimodal_max_length:]
110
+ pad_batch_pixel_values = [item for sublist in batch_pixel_values for item in sublist]
111
+
112
+ # generate output
113
+ with torch.inference_mode():
114
+ gen_kwargs = dict(
115
+ max_new_tokens=1024,
116
+ do_sample=False,
117
+ top_p=None,
118
+ top_k=None,
119
+ temperature=None,
120
+ repetition_penalty=None,
121
+ eos_token_id=model.generation_config.eos_token_id,
122
+ pad_token_id=text_tokenizer.pad_token_id,
123
+ use_cache=True
124
+ )
125
+ output_ids = model.generate(pad_batch_input_ids, pixel_values=pad_batch_pixel_values, attention_mask=pad_batch_attention_mask, **gen_kwargs)
126
+
127
+ for i in range(len(batch_input_ids)):
128
+ output = text_tokenizer.decode(output_ids[i], skip_special_tokens=True)
129
+ print(f'Output_{i}:\n{output}')
130
+ ```
131
+ </details>
132
+
133
  ## Citation
134
  If you find Ovis useful, please cite the paper
135
  ```