from __future__ import annotations

import PIL.Image
import torch
from diffusers import UniDiffuserPipeline


class Model:
    """Wrapper around the ``thu-ml/unidiffuser-v1`` ``UniDiffuserPipeline``.

    The pipeline is loaded once at construction time. On a CUDA machine it is
    loaded with ``torch.float16`` weights and moved to ``cuda:0``; otherwise
    it is loaded with the library's default dtype and left where
    ``from_pretrained`` puts it.
    """

    _MODEL_ID = 'thu-ml/unidiffuser-v1'

    # Every value accepted by ``run(mode=...)``.
    _MODES = ('t2i', 'i2t', 'joint', 'i', 't', 'i2t2i', 't2i2t')

    # Modes whose first step is conditioned on the caller-supplied image.
    _IMAGE_INPUT_MODES = frozenset({'i2t', 'i2t2i'})

    def __init__(self):
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        if self.device.type == 'cuda':
            self.pipe = UniDiffuserPipeline.from_pretrained(
                self._MODEL_ID, torch_dtype=torch.float16)
            self.pipe.to(self.device)
        else:
            self.pipe = UniDiffuserPipeline.from_pretrained(self._MODEL_ID)

    def run(
        self,
        mode: str,
        prompt: str,
        image: PIL.Image.Image | None,
        seed: int = 0,
        num_steps: int = 20,
        guidance_scale: float = 8.0,
    ) -> tuple[PIL.Image.Image | None, str]:
        """Run the pipeline in the requested mode.

        Args:
            mode: One of ``'t2i'`` (text -> image), ``'i2t'`` (image -> text),
                ``'joint'`` (image and text together), ``'i'`` (image only),
                ``'t'`` (text only), ``'i2t2i'`` (image -> text -> image) or
                ``'t2i2t'`` (text -> image -> text).
            prompt: Text condition; only read by ``'t2i'`` and ``'t2i2t'``.
            image: Image condition; only read by ``'i2t'`` and ``'i2t2i'``,
                where it must not be ``None``.
            seed: Seed for the ``torch.Generator`` handed to the pipeline.
            num_steps: Forwarded as ``num_inference_steps``.
            guidance_scale: Forwarded as ``guidance_scale``.

        Returns:
            ``(image, text)``. Modes that produce no image return ``None`` in
            the first slot; modes that produce no text return ``''`` in the
            second slot.

        Raises:
            ValueError: If ``mode`` is not a supported mode, or if an
                image-conditioned mode is requested with ``image=None``.
        """
        # Validate before allocating a generator or touching the pipeline so
        # that bad input fails fast with an actionable message.
        if mode not in self._MODES:
            raise ValueError(
                f'Unknown mode {mode!r}; expected one of '
                f'{", ".join(self._MODES)}.')
        if mode in self._IMAGE_INPUT_MODES and image is None:
            raise ValueError(f'Mode {mode!r} requires an input image.')

        generator = torch.Generator(device=self.device).manual_seed(seed)
        # One generator object is shared by every pipeline call below, so the
        # second step of a chained mode continues the same random stream.
        sampling = {
            'num_inference_steps': num_steps,
            'guidance_scale': guidance_scale,
            'generator': generator,
        }

        if mode == 't2i':
            return self._text_to_image(prompt, **sampling), ''
        if mode == 'i2t':
            return None, self._image_to_text(image, **sampling)
        if mode == 'joint':
            self.pipe.set_joint_mode()
            sample = self.pipe(**sampling)
            return sample.images[0], sample.text[0]
        if mode == 'i':
            self.pipe.set_image_mode()
            return self.pipe(**sampling).images[0], ''
        if mode == 't':
            self.pipe.set_text_mode()
            return None, self.pipe(**sampling).text[0]
        if mode == 'i2t2i':
            caption = self._image_to_text(image, **sampling)
            return self._text_to_image(caption, **sampling), ''

        # Only 't2i2t' remains; ``mode`` was validated above.
        generated = self._text_to_image(prompt, **sampling)
        return None, self._image_to_text(generated, **sampling)

    def _text_to_image(self, prompt: str, **sampling) -> PIL.Image.Image:
        """Switch the pipeline to text-to-image mode and return the first image."""
        self.pipe.set_text_to_image_mode()
        return self.pipe(prompt=prompt, **sampling).images[0]

    def _image_to_text(self, image: PIL.Image.Image, **sampling) -> str:
        """Switch the pipeline to image-to-text mode and return the first text."""
        self.pipe.set_image_to_text_mode()
        return self.pipe(image=image, **sampling).text[0]