File size: 3,143 Bytes
78c2594
 
eeba8b2
78c2594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eeba8b2
 
 
 
 
 
 
78c2594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eeba8b2
 
 
 
 
 
 
 
78c2594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eeba8b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
from PIL import Image
import requests

from tld.denoiser import Denoiser
from tld.diffusion import DiffusionGenerator

from diffusers import AutoencoderKL, AutoencoderTiny
from tqdm import tqdm
import clip
import torch
import numpy as np
import torchvision.utils as vutils
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset
from PIL import Image

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
to_pil = transforms.ToPILImage()


def download_file(url, filename, timeout=30):
    """Stream the resource at ``url`` into the local file ``filename``.

    The response body is written in 8 KiB chunks, so the payload is never
    held in memory as a whole.

    Args:
        url: Address fetched with an HTTP GET.
        filename: Destination path. It is opened in binary write mode, so an
            existing file at that path is overwritten.
        timeout: Seconds to wait for the connection and for each read from
            the socket. It is not a limit on the total download time.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server stops responding for ``timeout``
            seconds.
    """
    # Without an explicit timeout a stalled server would block forever.
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        with open(filename, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)

@torch.no_grad()
def encode_text(label, model):
    """Return ``model``'s text encoding of ``label`` as a CPU tensor.

    Args:
        label: Text (or list of texts) handed to ``clip.tokenize``.
        model: Object exposing ``encode_text``; called with the token tensor.

    Returns:
        The result of ``model.encode_text`` moved to the CPU.
    """
    # truncate=True asks the tokenizer to cut over-long text instead of failing.
    tokens = clip.tokenize(label, truncate=True)
    # Tokens are moved to the module-level device before the forward pass.
    encoded = model.encode_text(tokens.to(device))
    return encoded.cpu()

def generate_image_from_text(prompt, class_guidance=6, seed=11, num_imgs=1, img_size=32):
    """Generate ``num_imgs`` images for ``prompt``, save them as a grid, and return it.

    The prompt is repeated once per image, encoded with the module-level CLIP
    model, and passed to the module-level ``diffuser``. The generated batch is
    tiled into one grid image, written to a PNG in the current working
    directory, and returned.

    Args:
        prompt: Text describing the image.
        class_guidance: Forwarded to ``diffuser.generate`` as ``class_guidance``.
        seed: Forwarded to ``diffuser.generate`` as ``seed``.
        num_imgs: Number of images to generate for the prompt.
        img_size: Currently unused; kept so existing callers keep working.

    Returns:
        A PIL image containing the grid of generated images.
    """
    n_iter = 15
    # Images per grid row; the square root gives a roughly square layout.
    nrow = int(np.sqrt(num_imgs))

    cur_prompts = [prompt] * num_imgs
    labels = encode_text(cur_prompts, clip_model)
    # The second return value (named out_latent in the original) is not used here.
    out, _ = diffuser.generate(
        labels=labels,
        num_imgs=num_imgs,
        class_guidance=class_guidance,
        seed=seed,
        n_iter=n_iter,
        exponent=1,
        scale_factor=8,
        sharp_f=0,
        bright_f=0,
    )

    # Presumably `out` is in [-1, 1]; (out + 1) / 2 rescales it before clipping
    # to the [0, 1] range -- TODO confirm against DiffusionGenerator.
    grid = vutils.make_grid((out + 1) / 2, nrow=nrow, padding=4)
    out = to_pil(grid.float().clip(0, 1))

    # The prompt is free-form user text and must not be used verbatim as a file
    # name: a path separator would make the save fail or land in another
    # directory, ':' is invalid on some filesystems, and a long prompt exceeds
    # the file-name length limit. Keep a conservative ASCII subset and cap it.
    safe_prompt = "".join(
        ch if ch.isascii() and (ch.isalnum() or ch in "-_ ") else "_"
        for ch in prompt
    )[:100]
    out.save(f'{safe_prompt}_cfg-{class_guidance}_seed-{seed}.png')

    print("Images Generated and Saved. They will shortly output below.")
    return out

### Config
# NOTE(review): vae_scale_factor is not referenced by the code visible in this
# file; generate_image_from_text passes a literal scale_factor=8 instead.
vae_scale_factor = 8
img_size = 32
model_dtype = torch.float32

# Fetch the denoiser checkpoint. This runs on every start-up, because
# download_file overwrites any existing local copy.
file_url = "https://huggingface.co/apapiu/small_ldt/resolve/main/state_dict_378000.pth"
local_filename = "state_dict_378000.pth"
download_file(file_url, local_filename)


# image_size comes from the config above, so the two cannot drift apart.
denoiser = Denoiser(image_size=img_size, noise_embed_dims=256, patch_size=2,
                 embed_dim=768, dropout=0, n_layers=12)


# Load from the path that was just downloaded instead of repeating the literal.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
state_dict = torch.load(local_filename, map_location=torch.device('cpu'))

# Cast to the configured dtype and load the weights while still on the CPU,
# then move the model to the target device.
denoiser = denoiser.to(model_dtype)
denoiser.load_state_dict(state_dict)
denoiser = denoiser.to(device)

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix",
                                    torch_dtype=model_dtype).to(device)

# `preprocess` is returned by clip.load but is not used in the visible code.
clip_model, preprocess = clip.load("ViT-L/14")
clip_model = clip_model.to(device)

diffuser = DiffusionGenerator(denoiser, vae, device, model_dtype)

# Define the Gradio interface.
# The slider is configured explicitly: the bare "slider" shortcut starts at its
# minimum (0), which would override the function's intended default
# class_guidance of 6 on every request from the UI.
iface = gr.Interface(
    fn=generate_image_from_text,  # Called with (prompt, class_guidance)
    inputs=[
        "text",
        gr.Slider(minimum=0, maximum=100, value=6, label="Class guidance"),
    ],
    outputs="image",
    title="Text-to-Image Generator",
    description="Enter a text prompt to generate an image."
)

# Launch the app
iface.launch()