ryanzhangfan commited on
Commit
cb9cf0f
1 Parent(s): c4ec650

Upload 2 files

Browse files
Files changed (2) hide show
  1. model_index.json +1 -1
  2. pipeline_emu2_gen.py +28 -12
model_index.json CHANGED
@@ -6,7 +6,7 @@
6
  "CLIPImageProcessor"
7
  ],
8
  "multimodal_encoder": [
9
- "transformers_modules.modeling_emu",
10
  "EmuForCausalLM"
11
  ],
12
  "safety_checker": [
 
6
  "CLIPImageProcessor"
7
  ],
8
  "multimodal_encoder": [
9
+ "transformers_modules.multimodal_encoder.modeling_emu",
10
  "EmuForCausalLM"
11
  ],
12
  "safety_checker": [
pipeline_emu2_gen.py CHANGED
@@ -8,14 +8,14 @@
8
  # Email : zhangfan@baai.ac.cn
9
  # Institute : Beijing Academy of Artificial Intelligence (BAAI)
10
  # Create On : 2023-12-19 10:45
11
- # Last Modified : 2023-12-19 14:01
12
- # File Name : pipeline.py
13
  # Description :
14
  #
15
  # ===========================================================================================
16
 
17
  from dataclasses import dataclass
18
- from typing import List, Optional, Union
19
 
20
  from PIL import Image
21
  import numpy as np
@@ -38,8 +38,8 @@ DEFAULT_IMG_PLACEHOLDER = "[<IMG_PLH>]"
38
 
39
  @dataclass
40
  class EmuVisualGenerationPipelineOutput(BaseOutput):
41
- images: Union[List[Image.Image], np.ndarray]
42
- nsfw_content_detected: Optional[List[bool]]
43
 
44
 
45
  class EmuVisualGenerationPipeline(DiffusionPipeline):
@@ -76,7 +76,7 @@ class EmuVisualGenerationPipeline(DiffusionPipeline):
76
  TF.Normalize(mean=eva_mean, std=eva_std),
77
  ])
78
 
79
- self.negative_prompt = None
80
 
81
  def device(self, module):
82
  return next(module.parameters()).device
@@ -166,7 +166,10 @@ class EmuVisualGenerationPipeline(DiffusionPipeline):
166
 
167
  # 7. Convert to PIL
168
  images = self.numpy_to_pil(images)
169
- return EmuVisualGenerationPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
 
 
 
170
 
171
  def _prepare_and_encode_inputs(
172
  self,
@@ -177,11 +180,14 @@ class EmuVisualGenerationPipeline(DiffusionPipeline):
177
  device = self.device(self.multimodal_encoder.model.visual)
178
  dtype = self.dtype(self.multimodal_encoder.model.visual)
179
 
 
180
  text_prompt, image_prompt = "", []
181
  for x in inputs:
182
  if isinstance(x, str):
 
183
  text_prompt += x
184
  else:
 
185
  text_prompt += placeholder
186
  image_prompt.append(self.transform(x))
187
 
@@ -191,11 +197,21 @@ class EmuVisualGenerationPipeline(DiffusionPipeline):
191
  image_prompt = torch.stack(image_prompt)
192
  image_prompt = image_prompt.type(dtype).to(device)
193
 
194
- prompt = self.multimodal_encoder.generate_image(text=[text_prompt], image=image_prompt, tokenizer=self.tokenizer)
195
- if do_classifier_free_guidance:
196
- if self.negative_prompt is None:
197
- self.negative_prompt = self.multimodal_encoder.generate_image(text=[""], tokenizer=self.tokenizer)
198
- prompt = torch.cat([prompt, self.negative_prompt], dim=0)
 
 
 
 
 
 
 
 
 
 
199
 
200
  return prompt
201
 
 
8
  # Email : zhangfan@baai.ac.cn
9
  # Institute : Beijing Academy of Artificial Intelligence (BAAI)
10
  # Create On : 2023-12-19 10:45
11
+ # Last Modified : 2023-12-25 07:59
12
+ # File Name : pipeline_emu2_gen.py
13
  # Description :
14
  #
15
  # ===========================================================================================
16
 
17
  from dataclasses import dataclass
18
+ from typing import List, Optional
19
 
20
  from PIL import Image
21
  import numpy as np
 
38
 
39
  @dataclass
40
  class EmuVisualGenerationPipelineOutput(BaseOutput):
41
+ image: Image.Image
42
+ nsfw_content_detected: Optional[bool]
43
 
44
 
45
  class EmuVisualGenerationPipeline(DiffusionPipeline):
 
76
  TF.Normalize(mean=eva_mean, std=eva_std),
77
  ])
78
 
79
+ self.negative_prompt = {}
80
 
81
  def device(self, module):
82
  return next(module.parameters()).device
 
166
 
167
  # 7. Convert to PIL
168
  images = self.numpy_to_pil(images)
169
+ return EmuVisualGenerationPipelineOutput(
170
+ image=images[0],
171
+ nsfw_content_detected=None if has_nsfw_concept is None else has_nsfw_concept[0],
172
+ )
173
 
174
  def _prepare_and_encode_inputs(
175
  self,
 
180
  device = self.device(self.multimodal_encoder.model.visual)
181
  dtype = self.dtype(self.multimodal_encoder.model.visual)
182
 
183
+ has_image, has_text = False, False
184
  text_prompt, image_prompt = "", []
185
  for x in inputs:
186
  if isinstance(x, str):
187
+ has_text = True
188
  text_prompt += x
189
  else:
190
+ has_image = True
191
  text_prompt += placeholder
192
  image_prompt.append(self.transform(x))
193
 
 
197
  image_prompt = torch.stack(image_prompt)
198
  image_prompt = image_prompt.type(dtype).to(device)
199
 
200
+ if has_image and not has_text:
201
+ prompt = self.multimodal_encoder.model.encode_image(image=image_prompt)
202
+ if do_classifier_free_guidance:
203
+ key = "[NULL_IMAGE]"
204
+ if key not in self.negative_prompt:
205
+ negative_image = torch.zeros_like(image_prompt)
206
+ self.negative_prompt[key] = self.multimodal_encoder.model.encode_image(image=negative_image)
207
+ prompt = torch.cat([prompt, self.negative_prompt[key]], dim=0)
208
+ else:
209
+ prompt = self.multimodal_encoder.generate_image(text=[text_prompt], image=image_prompt, tokenizer=self.tokenizer)
210
+ if do_classifier_free_guidance:
211
+ key = ""
212
+ if key not in self.negative_prompt:
213
+ self.negative_prompt[key] = self.multimodal_encoder.generate_image(text=[""], tokenizer=self.tokenizer)
214
+ prompt = torch.cat([prompt, self.negative_prompt[key]], dim=0)
215
 
216
  return prompt
217