Update pipeline.py
Browse files
- pipeline.py +5 -6
- prompt_encoder/encoder.py +30 -11
pipeline.py
CHANGED
@@ -62,7 +62,8 @@ def postprocess(
|
|
62 |
output_type = "np"
|
63 |
|
64 |
image = image.detach().cpu()
|
65 |
-
|
|
|
66 |
if output_type == "latent":
|
67 |
return image
|
68 |
|
@@ -412,9 +413,7 @@ class MatForgerPipeline(DiffusionPipeline, FromSingleFileMixin):
|
|
412 |
raise ValueError(
|
413 |
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
414 |
)
|
415 |
-
elif prompt is not None and (
|
416 |
-
not isinstance(prompt, str) and not isinstance(prompt, list)
|
417 |
-
):
|
418 |
raise ValueError(
|
419 |
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
420 |
)
|
@@ -671,8 +670,8 @@ class MatForgerPipeline(DiffusionPipeline, FromSingleFileMixin):
|
|
671 |
] = None,
|
672 |
height: Optional[int] = None,
|
673 |
width: Optional[int] = None,
|
674 |
-
tileable: bool =
|
675 |
-
patched: bool =
|
676 |
num_inference_steps: int = 50,
|
677 |
timesteps: List[int] = None,
|
678 |
guidance_scale: float = 7.5,
|
|
|
62 |
output_type = "np"
|
63 |
|
64 |
image = image.detach().cpu()
|
65 |
+
image = image.to(torch.float32)
|
66 |
+
|
67 |
if output_type == "latent":
|
68 |
return image
|
69 |
|
|
|
413 |
raise ValueError(
|
414 |
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
415 |
)
|
416 |
+
elif prompt is not None and (not isinstance(prompt, (str, list, Image.Image))):
|
|
|
|
|
417 |
raise ValueError(
|
418 |
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
419 |
)
|
|
|
670 |
] = None,
|
671 |
height: Optional[int] = None,
|
672 |
width: Optional[int] = None,
|
673 |
+
tileable: bool = False,
|
674 |
+
patched: bool = False,
|
675 |
num_inference_steps: int = 50,
|
676 |
timesteps: List[int] = None,
|
677 |
guidance_scale: float = 7.5,
|
prompt_encoder/encoder.py
CHANGED
@@ -1,6 +1,13 @@
|
|
1 |
-
from typing import List,
|
2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from diffusers.configuration_utils import ConfigMixin
|
|
|
4 |
from diffusers.models.modeling_utils import ModelMixin
|
5 |
from PIL import Image
|
6 |
from transformers import (
|
@@ -10,6 +17,15 @@ from transformers import (
|
|
10 |
CLIPVisionModelWithProjection,
|
11 |
)
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
class BasePromptEncoder(ModelMixin, ConfigMixin):
|
15 |
def __init__(self):
|
@@ -59,16 +75,19 @@ class MaterialPromptEncoder(BasePromptEncoder):
|
|
59 |
self,
|
60 |
prompt,
|
61 |
):
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
72 |
|
73 |
def forward(
|
74 |
self,
|
|
|
1 |
+
from typing import List, Union, get_args
|
2 |
|
3 |
+
import PIL
|
4 |
+
import PIL.Jpeg2KImagePlugin
|
5 |
+
import PIL.JpegImagePlugin
|
6 |
+
import PIL.PngImagePlugin
|
7 |
+
import PIL.TiffImagePlugin
|
8 |
+
import torch
|
9 |
from diffusers.configuration_utils import ConfigMixin
|
10 |
+
from diffusers.image_processor import PipelineImageInput
|
11 |
from diffusers.models.modeling_utils import ModelMixin
|
12 |
from PIL import Image
|
13 |
from transformers import (
|
|
|
17 |
CLIPVisionModelWithProjection,
|
18 |
)
|
19 |
|
20 |
+
StrInput = Union[str, List[str]]
|
21 |
+
|
22 |
+
ImageInput = Union[
|
23 |
+
PIL.JpegImagePlugin.JpegImageFile,
|
24 |
+
PIL.Jpeg2KImagePlugin.Jpeg2KImageFile,
|
25 |
+
PIL.PngImagePlugin.PngImageFile,
|
26 |
+
PIL.TiffImagePlugin.TiffImageFile,
|
27 |
+
]
|
28 |
+
|
29 |
|
30 |
class BasePromptEncoder(ModelMixin, ConfigMixin):
|
31 |
def __init__(self):
|
|
|
75 |
self,
|
76 |
prompt,
|
77 |
):
|
78 |
+
if type(prompt) != list:
|
79 |
+
prompt = [prompt]
|
80 |
+
|
81 |
+
embs = []
|
82 |
+
for prompt in prompt:
|
83 |
+
if isinstance(prompt, str):
|
84 |
+
embs.append(self.encode_text(prompt))
|
85 |
+
elif type(prompt, get_args(ImageInput)):
|
86 |
+
embs.append(self.encode_image(prompt))
|
87 |
+
else:
|
88 |
+
raise NotImplementedError
|
89 |
+
|
90 |
+
return torch.cat(embs, dim=0)
|
91 |
|
92 |
def forward(
|
93 |
self,
|