Spaces:
Running
on
Zero
Running
on
Zero
from __future__ import annotations | |
import PIL.Image | |
import torch | |
from diffusers import UniDiffuserPipeline | |
class Model:
    """Wrapper around the ``thu-ml/unidiffuser-v1`` UniDiffuser pipeline.

    The pipeline is loaded once at construction time. When CUDA is available
    it is loaded in float16 and moved to ``cuda:0``; otherwise it is loaded
    with the library's default dtype and left where ``from_pretrained`` put it.
    """

    # Every mode string accepted by ``run``.
    _MODES = ("t2i", "i2t", "joint", "i", "t", "i2t2i", "t2i2t")
    # Modes whose first stage is conditioned on the caller-supplied image.
    _IMAGE_INPUT_MODES = ("i2t", "i2t2i")

    def __init__(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if self.device.type == "cuda":
            self.pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
            self.pipe.to(self.device)
        else:
            self.pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")

    def run(
        self,
        mode: str,
        prompt: str,
        image: PIL.Image.Image | None,
        seed: int = 0,
        num_steps: int = 20,
        guidance_scale: float = 8.0,
    ) -> tuple[PIL.Image.Image | None, str]:
        """Run the pipeline in the requested mode.

        Args:
            mode: One of ``"t2i"``, ``"i2t"``, ``"joint"``, ``"i"``, ``"t"``,
                ``"i2t2i"`` or ``"t2i2t"``.
            prompt: Text input; only read by ``"t2i"`` and ``"t2i2t"``.
            image: Image input; only read by ``"i2t"`` and ``"i2t2i"``, where
                it must not be ``None``.
            seed: Seed for the ``torch.Generator`` passed to every pipeline call.
            num_steps: Forwarded as ``num_inference_steps``.
            guidance_scale: Forwarded as ``guidance_scale``.

        Returns:
            An ``(image, text)`` tuple. ``image`` is ``None`` for modes that
            only produce text; ``text`` is ``""`` for modes that only produce
            an image.

        Raises:
            ValueError: If ``mode`` is not a supported mode, or if an
                image-conditioned mode is requested with ``image=None``.
        """
        # Validate before touching the pipeline so bad requests fail fast
        # with an actionable message.
        if mode not in self._MODES:
            raise ValueError(f"Unknown mode: {mode!r}. Expected one of {list(self._MODES)}.")
        if mode in self._IMAGE_INPUT_MODES and image is None:
            raise ValueError(f"Mode {mode!r} requires an input image, but image is None.")

        generator = torch.Generator(device=self.device).manual_seed(seed)
        # The same generator object is shared by both stages of the chained
        # modes ("i2t2i", "t2i2t"), so the second stage continues from the
        # RNG state left by the first.
        common: dict[str, object] = {
            "num_inference_steps": num_steps,
            "guidance_scale": guidance_scale,
            "generator": generator,
        }

        if mode == "t2i":
            return self._text_to_image(prompt, common), ""
        if mode == "i2t":
            return None, self._image_to_text(image, common)
        if mode == "joint":
            self.pipe.set_joint_mode()
            sample = self.pipe(**common)
            return sample.images[0], sample.text[0]
        if mode == "i":
            self.pipe.set_image_mode()
            sample = self.pipe(**common)
            return sample.images[0], ""
        if mode == "t":
            self.pipe.set_text_mode()
            sample = self.pipe(**common)
            return None, sample.text[0]
        if mode == "i2t2i":
            # Caption the input image, then generate an image from the caption.
            caption = self._image_to_text(image, common)
            return self._text_to_image(caption, common), ""
        # Only "t2i2t" remains: generate an image from the prompt, then caption it.
        generated = self._text_to_image(prompt, common)
        return None, self._image_to_text(generated, common)

    def _text_to_image(self, prompt: str, common: dict[str, object]) -> PIL.Image.Image:
        """Switch the pipeline to text-to-image mode and return the first image."""
        self.pipe.set_text_to_image_mode()
        return self.pipe(prompt=prompt, **common).images[0]

    def _image_to_text(self, image: PIL.Image.Image, common: dict[str, object]) -> str:
        """Switch the pipeline to image-to-text mode and return the first text."""
        self.pipe.set_image_to_text_mode()
        return self.pipe(image=image, **common).text[0]