File size: 3,159 Bytes
c1b6d44
 
5a6a815
c1b6d44
5a6a815
b5f510e
c51e381
5a6a815
df30d9a
 
5a6a815
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e54ddd1
 
 
5a6a815
 
 
 
e54ddd1
5a6a815
e54ddd1
 
5a6a815
e54ddd1
 
 
5a6a815
 
 
 
e54ddd1
 
 
 
 
 
5a6a815
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr

import torch

from diffusers import UniDiffuserPipeline


# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint served by this app. The older release, "thu-ml/unidiffuser-v0",
# can be substituted here.
model_id = "thu-ml/unidiffuser-v1"

# Single module-level pipeline shared by every request handled by sample().
pipeline = UniDiffuserPipeline.from_pretrained(model_id)
pipeline.to(device)


def convert_to_none(s):
    """Return *s* unchanged when it is truthy, otherwise return ``None``.

    Used to normalise "empty" UI values (such as the blank default of the
    prompt textbox) to ``None`` before they are passed to the pipeline.
    """
    return s or None


def set_mode(mode):
    """Put the shared UniDiffuser ``pipeline`` into the requested generation mode.

    Args:
        mode: One of ``"joint"``, ``"text2img"``, ``"img2text"``, ``"text"`` or
            ``"img"``. Surrounding whitespace and letter case are ignored.
            A blank value (or ``None``) clears any manually set mode with
            ``pipeline.reset_mode()``, so the pipeline infers the mode from
            the inputs it receives.

    Raises:
        ValueError: If *mode* is non-blank but names no known generation mode.
    """
    # Map each task name to the matching UniDiffuserPipeline setter. Note the
    # image-to-text setter is ``set_image_to_text_mode``; the previous call to
    # ``set_image_text_mode`` does not exist and raised AttributeError.
    setters = {
        "joint": pipeline.set_joint_mode,
        "text2img": pipeline.set_text_to_image_mode,
        "img2text": pipeline.set_image_to_text_mode,
        "text": pipeline.set_text_mode,
        "img": pipeline.set_image_mode,
    }
    key = (mode or "").strip().lower()
    if not key:
        # ``pipeline`` is a module-level singleton. Without an explicit reset,
        # a mode chosen by an earlier request would leak into this one.
        pipeline.reset_mode()
        return
    try:
        setter = setters[key]
    except KeyError:
        raise ValueError(
            f"Unknown generation task {mode!r}; expected one of: "
            f"{', '.join(setters)}"
        ) from None
    setter()


def sample(mode, prompt, image, num_inference_steps, guidance_scale, seed):
    """Run one UniDiffuser generation and return ``(image, text)`` for the UI.

    The pipeline is first switched to *mode* via ``set_mode``. Falsy
    *prompt*/*image* values are passed to the pipeline as ``None``. Sampling
    uses a ``torch.Generator`` on ``device`` seeded with *seed*, so identical
    inputs reproduce the same result.

    Returns:
        The first generated image (or ``None`` when the pipeline produced no
        images) and the first generated text (or ``""`` when it produced none).
    """
    set_mode(mode)
    result = pipeline(
        prompt=convert_to_none(prompt),
        image=convert_to_none(image),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device=device).manual_seed(seed),
    )

    images, texts = result.images, result.text
    first_image = images[0] if images is not None else None
    first_text = texts[0] if texts is not None else ""
    return first_image, first_text
    

# Gradio UI. The six inputs map positionally onto
# sample(mode, prompt, image, num_inference_steps, guidance_scale, seed),
# and the two outputs receive sample()'s (image, text) return value.
iface = gr.Interface(
    fn=sample,
    inputs=[
        # Free-text task name. The placeholder lists the mode strings that
        # set_mode() understands, since nothing else in the UI tells the user.
        gr.Textbox(
            value="",
            label="Generation Task",
            placeholder="joint | text2img | img2text | text | img",
        ),
        gr.Textbox(value="", label="Conditioning prompt"),
        # type="pil" hands the uploaded image to sample() as a PIL image.
        gr.Image(value=None, label="Conditioning image", type="pil"),
        gr.Number(value=20, label="Num Inference Steps", precision=0),
        gr.Number(value=8.0, label="Guidance Scale"),
        gr.Number(value=0, label="Seed", precision=0),
    ],
    outputs=[
        gr.Image(label="Sample image"),
        gr.Textbox(label="Sample text"),
    ],
)
iface.launch()

# from unidiffuser.sample_v0 import sample
# from unidiffuser.sample_v0_test import sample
# from unidiffuser.sample_v1 import sample
# from unidiffuser.sample_v1_test import sample


# def predict(mode, prompt, image, sample_steps, guidance_scale, seed):
#     output_images, output_text = sample(
#         mode, prompt, image, sample_steps=sample_steps, scale=guidance_scale, seed=seed,
#     )
#     sample_image = None
#     sample_text = ""
#     if output_images is not None:
#         sample_image = output_images[0]
#     if output_text is not None:
#         sample_text = output_text[0]
#     return sample_image, sample_text


# iface = gr.Interface(
#     fn=predict,
#     inputs=[
#         gr.Textbox(value="", label="Generation Task"),
#         gr.Textbox(value="", label="Conditioning prompt"),
#         gr.Image(value=None, label="Conditioning image", type="filepath"),
#         gr.Number(value=50, label="Num Inference Steps", precision=0),
#         gr.Number(value=7.0, label="Guidance Scale"),
#         gr.Number(value=1234, label="Seed", precision=0),
#     ],
#     outputs=[
#         gr.Image(label="Sample image"),
#         gr.Textbox(label="Sample text"),
#     ],
# )
# iface.launch()