from __future__ import annotations

import PIL.Image
import torch
from diffusers import UniDiffuserPipeline


class Model:
    """Thin wrapper around UniDiffuser that exposes all of its sampling modes."""

    def __init__(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if self.device.type == "cuda":
            # Load in float16 on GPU to roughly halve memory use.
            self.pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
            self.pipe.to(self.device)
        else:
            # Fall back to full precision on CPU, where float16 is poorly supported.
            self.pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")

    def run(
        self,
        mode: str,
        prompt: str,
        image: PIL.Image.Image | None,
        seed: int = 0,
        num_steps: int = 20,
        guidance_scale: float = 8.0,
    ) -> tuple[PIL.Image.Image | None, str]:
        # Seed a dedicated generator so each call is reproducible for a given seed.
        generator = torch.Generator(device=self.device).manual_seed(seed)
        if mode == "t2i":
            # Text-to-image: condition on the prompt and return one image.
            self.pipe.set_text_to_image_mode()
            sample = self.pipe(
                prompt=prompt, num_inference_steps=num_steps, guidance_scale=guidance_scale, generator=generator
            )
            return sample.images[0], ""
        elif mode == "i2t":
            self.pipe.set_image_to_text_mode()
            sample = self.pipe(
                image=image, num_inference_steps=num_steps, guidance_scale=guidance_scale, generator=generator
            )
            return None, sample.text[0]
        elif mode == "joint":
            self.pipe.set_joint_mode()
            sample = self.pipe(num_inference_steps=num_steps, guidance_scale=guidance_scale, generator=generator)
            return sample.images[0], sample.text[0]
        elif mode == "i":
            self.pipe.set_image_mode()
            sample = self.pipe(num_inference_steps=num_steps, guidance_scale=guidance_scale, generator=generator)
            return sample.images[0], ""
        elif mode == "t":
            self.pipe.set_text_mode()
            sample = self.pipe(num_inference_steps=num_steps, guidance_scale=guidance_scale, generator=generator)
            return None, sample.text[0]
        elif mode == "i2t2i":
            self.pipe.set_image_to_text_mode()
            sample = self.pipe(
                image=image, num_inference_steps=num_steps, guidance_scale=guidance_scale, generator=generator
            )
            self.pipe.set_text_to_image_mode()
            sample = self.pipe(
                prompt=sample.text[0],
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                generator=generator,
            )
            return sample.images[0], ""
        elif mode == "t2i2t":
            self.pipe.set_text_to_image_mode()
            sample = self.pipe(
                prompt=prompt, num_inference_steps=num_steps, guidance_scale=guidance_scale, generator=generator
            )
            self.pipe.set_image_to_text_mode()
            sample = self.pipe(
                image=sample.images[0],
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                generator=generator,
            )
            return None, sample.text[0]
        else:
            raise ValueError(f"Unknown mode: {mode!r}")
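

# A minimal usage sketch (an assumption, not part of the original file): it chains
# the "t2i" and "i2t" modes exposed by Model.run above; the prompt and output
# path are illustrative placeholders.
if __name__ == "__main__":
    model = Model()
    image, _ = model.run(mode="t2i", prompt="an astronaut riding a horse", image=None)
    if image is not None:
        image.save("sample.png")  # hypothetical output path
        _, caption = model.run(mode="i2t", prompt="", image=image)
        print(caption)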