File size: 3,821 Bytes
dada74e
 
 
 
 
 
 
 
 
 
51a2c42
8b5d788
f1c062a
2640fff
 
dada74e
 
 
 
8b5d788
 
dada74e
2307701
15a2a80
2307701
15a2a80
 
 
 
 
 
 
 
 
 
dada74e
a331dda
dada74e
51a2c42
a331dda
995325f
dada74e
e17fc09
2307701
 
 
dada74e
 
 
 
 
 
995325f
 
dada74e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48670f6
dada74e
 
 
 
 
 
 
 
e5da114
 
 
 
b17e7eb
dada74e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Reference: https://huggingface.co/spaces/FoundationVision/LlamaGen/blob/main/app.py
from PIL import Image
import gradio as gr
from imagenet_classes import imagenet_idx2classname
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import time
import demo_util
import os
import spaces
from huggingface_hub import hf_hub_download

# NOTE(review): this upgrades numpy after `torch` has already been imported
# above; if torch has loaded numpy, the running process keeps the old version.
# Confirm whether this is still needed.
os.system("pip3 install -U numpy")

# Model name -> (tokenizer checkpoint, generator checkpoint) filenames in the
# "fun-research/TiTok" Hugging Face repo.
model2ckpt = {
    "TiTok-L-32": ("tokenizer_titok_l32.bin", "generator_titok_l32.bin"),
}

# Download every checkpoint listed in `model2ckpt` into the working directory,
# so the filenames are declared in exactly one place.
for _ckpt_files in model2ckpt.values():
    for _ckpt_file in _ckpt_files:
        hf_hub_download(repo_id="fun-research/TiTok", filename=_ckpt_file, local_dir="./")
del _ckpt_files, _ckpt_file

# (@spaces.GPU is left disabled here; this function runs once, at import time.)
def load_model():
    """Construct the TiTok tokenizer and generator and place them on CUDA.

    The config is read from ``configs/titok_l32.yaml`` through
    ``demo_util.get_config``. The tokenizer and generator are then built from
    it, and the config and each model are printed for debugging.

    Returns:
        tuple: ``(titok_tokenizer, titok_generator)``, each moved to ``"cuda"``.
    """
    # Device is fixed to CUDA; there is no CPU fallback.
    target_device = "cuda"

    cfg = demo_util.get_config("configs/titok_l32.yaml")
    print(cfg)

    tokenizer = demo_util.get_titok_tokenizer(cfg)
    print(tokenizer)

    generator = demo_util.get_titok_generator(cfg)
    print(generator)

    # Tuple elements evaluate left to right: tokenizer moves first, then generator.
    return tokenizer.to(target_device), generator.to(target_device)


# Built once at import; demo_infer reads these module-level globals.
titok_tokenizer, titok_generator = load_model()

@spaces.GPU
def demo_infer(
               guidance_scale, randomize_temperature, num_sample_steps,
               class_label, seed):
    """Generate four images of a single class with the TiTok generator.

    Args:
        guidance_scale: Classifier-free guidance scale, forwarded to
            ``demo_util.sample_fn``.
        randomize_temperature: Temperature value, forwarded to
            ``demo_util.sample_fn``.
        num_sample_steps: Number of sampling steps. Coerced to ``int`` because
            it originates from a UI slider.
        class_label: Class index, repeated once per generated image.
        seed: RNG seed. Coerced to ``int`` before seeding torch.

    Returns:
        list: One ``PIL.Image.Image`` per array returned by ``sample_fn``.
    """
    # The module-level models were already moved to CUDA by load_model().
    device = "cuda"
    num_images = 4
    class_labels = [class_label] * num_images

    seed = int(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    generated_image = demo_util.sample_fn(
        generator=titok_generator,
        tokenizer=titok_tokenizer,
        labels=class_labels,
        guidance_scale=guidance_scale,
        randomize_temperature=randomize_temperature,
        # Slider values may arrive as floats; a step count must be an int.
        num_sample_steps=int(num_sample_steps),
        device=device,
    )
    sampling_time = time.perf_counter() - start
    print(f"generation takes about {sampling_time:.2f} seconds.")
    return [Image.fromarray(sample) for sample in generated_image]

# Gradio front end: one "Generate" tab with sampling controls on the left and
# a gallery of generated images on the right.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center'>An Image is Worth 32 Tokens for Reconstruction and Generation</h1>")

    with gr.Tabs():
        with gr.TabItem('Generate'):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        # type="index" passes the position of the chosen name in
                        # this list to demo_infer as the class label.
                        # NOTE(review): assumes imagenet_idx2classname iterates
                        # in class-id order — confirm.
                        i1k_class = gr.Dropdown(
                            list(imagenet_idx2classname.values()),
                            value='Eskimo dog, husky',
                            type="index", label='ImageNet-1K Class'
                        )
                    guidance_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=3.5, label='Classifier-free Guidance Scale')
                    randomize_temperature = gr.Slider(minimum=0., maximum=10.0, step=0.1, value=1.0, label='randomize_temperature')
                    num_sample_steps = gr.Slider(minimum=1, maximum=32, step=1, value=8, label='num_sample_steps')
                    seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
                    button = gr.Button("Generate", variant="primary")
                with gr.Column():
                    # A single row of 4 columns for the generated images.
                    output = gr.Gallery(label='Generated Images',
                                        columns=4,
                                        rows=1,
                                        height=256, object_fit="scale-down")
                    # The input order must match demo_infer's positional parameters.
                    button.click(demo_infer, inputs=[
                        guidance_scale, randomize_temperature, num_sample_steps,
                        i1k_class, seed],
                        outputs=[output])

# Start the server only after the `with` block has closed, so the layout and
# event wiring are complete before launch (the pattern shown in the Gradio docs).
demo.queue()
demo.launch(debug=True)