rodrigomasini commited on
Commit
d134c7e
1 Parent(s): 96f7484

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -1
app.py CHANGED
@@ -1022,7 +1022,27 @@ from diffusers.utils import (
1022
  logging,
1023
  )
1024
 
1025
- from . import PhotoMakerIDEncoder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1026
 
1027
  PipelineImageInput = Union[
1028
  PIL.Image.Image,
 
1022
  logging,
1023
  )
1024
 
1025
+ class PhotoMakerIDEncoder(CLIPVisionModelWithProjection):
1026
+ def __init__(self):
1027
+ super().__init__(CLIPVisionConfig(**VISION_CONFIG_DICT))
1028
+ self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)
1029
+ self.fuse_module = FuseModule(2048)
1030
+
1031
+ def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
1032
+ b, num_inputs, c, h, w = id_pixel_values.shape
1033
+ id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
1034
+
1035
+ shared_id_embeds = self.vision_model(id_pixel_values)[1]
1036
+ id_embeds = self.visual_projection(shared_id_embeds)
1037
+ id_embeds_2 = self.visual_projection_2(shared_id_embeds)
1038
+
1039
+ id_embeds = id_embeds.view(b, num_inputs, 1, -1)
1040
+ id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
1041
+
1042
+ id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
1043
+ updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
1044
+
1045
+ return updated_prompt_embeds
1046
 
1047
  PipelineImageInput = Union[
1048
  PIL.Image.Image,