gvecchio committed on
Commit
d414290
1 Parent(s): e480f7f

Update pipeline.py

Browse files
Files changed (2) hide show
  1. pipeline.py +5 -6
  2. prompt_encoder/encoder.py +30 -11
pipeline.py CHANGED
@@ -62,7 +62,8 @@ def postprocess(
62
  output_type = "np"
63
 
64
  image = image.detach().cpu()
65
-
 
66
  if output_type == "latent":
67
  return image
68
 
@@ -412,9 +413,7 @@ class MatForgerPipeline(DiffusionPipeline, FromSingleFileMixin):
412
  raise ValueError(
413
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
414
  )
415
- elif prompt is not None and (
416
- not isinstance(prompt, str) and not isinstance(prompt, list)
417
- ):
418
  raise ValueError(
419
  f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
420
  )
@@ -671,8 +670,8 @@ class MatForgerPipeline(DiffusionPipeline, FromSingleFileMixin):
671
  ] = None,
672
  height: Optional[int] = None,
673
  width: Optional[int] = None,
674
- tileable: bool = True,
675
- patched: bool = True,
676
  num_inference_steps: int = 50,
677
  timesteps: List[int] = None,
678
  guidance_scale: float = 7.5,
 
62
  output_type = "np"
63
 
64
  image = image.detach().cpu()
65
+ image = image.to(torch.float32)
66
+
67
  if output_type == "latent":
68
  return image
69
 
 
413
  raise ValueError(
414
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
415
  )
416
+ elif prompt is not None and (not isinstance(prompt, (str, list, Image.Image))):
 
 
417
  raise ValueError(
418
  f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
419
  )
 
670
  ] = None,
671
  height: Optional[int] = None,
672
  width: Optional[int] = None,
673
+ tileable: bool = False,
674
+ patched: bool = False,
675
  num_inference_steps: int = 50,
676
  timesteps: List[int] = None,
677
  guidance_scale: float = 7.5,
prompt_encoder/encoder.py CHANGED
@@ -1,6 +1,13 @@
1
- from typing import List, Optional
2
 
 
 
 
 
 
 
3
  from diffusers.configuration_utils import ConfigMixin
 
4
  from diffusers.models.modeling_utils import ModelMixin
5
  from PIL import Image
6
  from transformers import (
@@ -10,6 +17,15 @@ from transformers import (
10
  CLIPVisionModelWithProjection,
11
  )
12
 
 
 
 
 
 
 
 
 
 
13
 
14
  class BasePromptEncoder(ModelMixin, ConfigMixin):
15
  def __init__(self):
@@ -59,16 +75,19 @@ class MaterialPromptEncoder(BasePromptEncoder):
59
  self,
60
  prompt,
61
  ):
62
- dtype = type(prompt)
63
- if dtype == list:
64
- dtype = type(prompt[0])
65
-
66
- if dtype == str:
67
- return self.encode_text(prompt)
68
- elif dtype == Image.Image:
69
- return self.encode_image(prompt)
70
- else:
71
- raise NotImplementedError
 
 
 
72
 
73
  def forward(
74
  self,
 
1
+ from typing import List, Union, get_args
2
 
3
+ import PIL
4
+ import PIL.Jpeg2KImagePlugin
5
+ import PIL.JpegImagePlugin
6
+ import PIL.PngImagePlugin
7
+ import PIL.TiffImagePlugin
8
+ import torch
9
  from diffusers.configuration_utils import ConfigMixin
10
+ from diffusers.image_processor import PipelineImageInput
11
  from diffusers.models.modeling_utils import ModelMixin
12
  from PIL import Image
13
  from transformers import (
 
17
  CLIPVisionModelWithProjection,
18
  )
19
 
20
+ StrInput = Union[str, List[str]]
21
+
22
+ ImageInput = Union[
23
+ PIL.JpegImagePlugin.JpegImageFile,
24
+ PIL.Jpeg2KImagePlugin.Jpeg2KImageFile,
25
+ PIL.PngImagePlugin.PngImageFile,
26
+ PIL.TiffImagePlugin.TiffImageFile,
27
+ ]
28
+
29
 
30
  class BasePromptEncoder(ModelMixin, ConfigMixin):
31
  def __init__(self):
 
75
  self,
76
  prompt,
77
  ):
78
+ if type(prompt) != list:
79
+ prompt = [prompt]
80
+
81
+ embs = []
82
+ for prompt in prompt:
83
+ if isinstance(prompt, str):
84
+ embs.append(self.encode_text(prompt))
85
+ elif type(prompt, get_args(ImageInput)):
86
+ embs.append(self.encode_image(prompt))
87
+ else:
88
+ raise NotImplementedError
89
+
90
+ return torch.cat(embs, dim=0)
91
 
92
  def forward(
93
  self,