Spaces:
Runtime error
Runtime error
File size: 3,171 Bytes
639c25d 2881ba6 639c25d 2149360 2881ba6 639c25d 2149360 2881ba6 2149360 2881ba6 2149360 2881ba6 2149360 2881ba6 2149360 2881ba6 2149360 2881ba6 2149360 2881ba6 2149360 2881ba6 2149360 2881ba6 2149360 2881ba6 2149360 2881ba6 2149360 2881ba6 2149360 2881ba6 2149360 2881ba6 639c25d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
from __future__ import annotations
import PIL.Image
import torch
from diffusers import UniDiffuserPipeline
class Model:
    """Thin wrapper around ``UniDiffuserPipeline`` that dispatches on a mode string.

    The pipeline is loaded once at construction time from the
    ``thu-ml/unidiffuser-v1`` checkpoint. On CUDA it is loaded in float16 and
    moved to ``cuda:0``; otherwise it is loaded with the default dtype.
    """

    # Modes whose first pipeline call consumes the caller-supplied image.
    _IMAGE_INPUT_MODES = frozenset({"i2t", "i2t2i"})

    def __init__(self) -> None:
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if self.device.type == "cuda":
            self.pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
            self.pipe.to(self.device)
        else:
            self.pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")

    def _text_to_image(self, prompt: str, **sampling_kwargs) -> PIL.Image.Image:
        """Switch the pipeline to text-to-image mode and return the first image."""
        self.pipe.set_text_to_image_mode()
        return self.pipe(prompt=prompt, **sampling_kwargs).images[0]

    def _image_to_text(self, image: PIL.Image.Image, **sampling_kwargs) -> str:
        """Switch the pipeline to image-to-text mode and return the first text."""
        self.pipe.set_image_to_text_mode()
        return self.pipe(image=image, **sampling_kwargs).text[0]

    def run(
        self,
        mode: str,
        prompt: str,
        image: PIL.Image.Image | None,
        seed: int = 0,
        num_steps: int = 20,
        guidance_scale: float = 8.0,
    ) -> tuple[PIL.Image.Image | None, str]:
        """Run the pipeline in the requested mode.

        Supported modes:

        * ``"t2i"``   - text-to-image from ``prompt``.
        * ``"i2t"``   - image-to-text from ``image``.
        * ``"joint"`` - joint mode; returns both an image and a text.
        * ``"i"``     - image mode; ``prompt`` and ``image`` are not used.
        * ``"t"``     - text mode; ``prompt`` and ``image`` are not used.
        * ``"i2t2i"`` - image-to-text on ``image``, then text-to-image on the
          produced text.
        * ``"t2i2t"`` - text-to-image on ``prompt``, then image-to-text on the
          produced image.

        Args:
            mode: One of the mode strings listed above.
            prompt: Text input; only read by ``"t2i"`` and ``"t2i2t"``.
            image: Image input; only read by ``"i2t"`` and ``"i2t2i"``.
            seed: Seed for the ``torch.Generator`` handed to the pipeline.
            num_steps: Passed to the pipeline as ``num_inference_steps``.
            guidance_scale: Passed to the pipeline as ``guidance_scale``.

        Returns:
            ``(image, text)``. ``image`` is ``None`` when the mode produces
            only text; ``text`` is ``""`` when the mode produces only an image.

        Raises:
            ValueError: If ``mode`` is not supported, or if ``image`` is
                ``None`` for a mode that needs an input image.
        """
        # Fail early with a clear message instead of an opaque error raised
        # from inside the pipeline when it is handed ``None`` as the image.
        if mode in self._IMAGE_INPUT_MODES and image is None:
            raise ValueError(f"Mode {mode!r} requires an input image, got None")

        generator = torch.Generator(device=self.device).manual_seed(seed)
        # Sampling arguments shared by every pipeline call. In the chained
        # modes the same generator object is used for both calls.
        sampling_kwargs = {
            "num_inference_steps": num_steps,
            "guidance_scale": guidance_scale,
            "generator": generator,
        }

        if mode == "t2i":
            return self._text_to_image(prompt, **sampling_kwargs), ""
        if mode == "i2t":
            return None, self._image_to_text(image, **sampling_kwargs)
        if mode == "joint":
            self.pipe.set_joint_mode()
            sample = self.pipe(**sampling_kwargs)
            return sample.images[0], sample.text[0]
        if mode == "i":
            self.pipe.set_image_mode()
            sample = self.pipe(**sampling_kwargs)
            return sample.images[0], ""
        if mode == "t":
            self.pipe.set_text_mode()
            sample = self.pipe(**sampling_kwargs)
            return None, sample.text[0]
        if mode == "i2t2i":
            caption = self._image_to_text(image, **sampling_kwargs)
            return self._text_to_image(caption, **sampling_kwargs), ""
        if mode == "t2i2t":
            generated = self._text_to_image(prompt, **sampling_kwargs)
            return None, self._image_to_text(generated, **sampling_kwargs)
        raise ValueError(f"Unknown mode: {mode!r}")
|