File size: 2,913 Bytes
ffd1a8e
 
 
 
 
 
 
686a471
ffd1a8e
 
cbc8ee5
ffd1a8e
 
cbc8ee5
 
ffd1a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
cbc8ee5
ffd1a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbc8ee5
ffd1a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers import LMSDiscreteScheduler
import torch
from tqdm.auto import tqdm
from PIL import Image
import gradio as gr
#from IPython.display import display

# CLIP ViT-L/14 tokenizer and text encoder; text_enc() turns prompts into the
# embeddings passed to the UNet as encoder_hidden_states.
# The models run in float32.  Loading them with torch_dtype=torch.float16 and
# then calling .float() would only round every weight through half precision
# before upcasting it again, so they are loaded at the default precision and
# .float() merely guarantees float32 whatever default the library picks.
# Tokenizers hold no tensors, so they take no dtype argument.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").float()

# Here we use a different VAE to the original release, which has been fine-tuned for more steps
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").float()
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").float()

# Noise schedule: "scaled_linear" betas from beta_start to beta_end over
# 1000 training timesteps.
beta_start, beta_end = 0.00085, 0.012
# Output image size in pixels (the latents are height//8 x width//8).
height = 512
width = 512
# Sampling hyper-parameters.
num_inference_steps = 70
guidance_scale = 7.5
batch_size = 1
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear", num_train_timesteps=1000)

#prompt = ["a photograph of an astronaut riding a horse"]

def text_enc(prompts, maxlen=None):
    """Encode a batch of prompts with the CLIP text encoder.

    Args:
        prompts: list of prompt strings, one per batch element.
        maxlen: token length every prompt is padded/truncated to; defaults
            to ``tokenizer.model_max_length``.

    Returns:
        The first element of the text encoder's output (the last hidden
        state for ``CLIPTextModel``), computed without gradient tracking.
    """
    if maxlen is None:
        maxlen = tokenizer.model_max_length
    inp = tokenizer(prompts, padding="max_length", max_length=maxlen,
                    truncation=True, return_tensors="pt")
    # input_ids are integer indices into the token-embedding table and must
    # stay integral: casting them to float makes the embedding lookup raise.
    # This is inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        return text_encoder(inp.input_ids)[0]

def mk_img(t):
    """Convert one decoded image tensor with values around [-1, 1] to a PIL image.

    ``t`` is channel-first; it is permuted to (H, W, C), rescaled to
    [0, 255] and converted to uint8 for PIL.
    """
    image = (t / 2 + 0.5).clamp(0, 1).detach().cpu().permute(1, 2, 0).numpy()
    return Image.fromarray((image * 255).round().astype("uint8"))


@torch.no_grad()
def mk_samples(prompts, g=None, seed=100, steps=None):
    """Sample latents for ``prompts`` with classifier-free guidance and decode them.

    Args:
        prompts: list of prompt strings; its length is the batch size.
        g: guidance scale; ``None`` uses the module-level ``guidance_scale``.
        seed: seed for the initial noise; ``None`` leaves it unseeded.
            ``0`` is a valid seed.
        steps: number of scheduler steps; ``None`` uses the module-level
            ``num_inference_steps``.

    Returns:
        The batch of decoded image tensors (``vae.decode(...).sample``).

    Note: calls ``set_timesteps`` on the shared module-level ``scheduler``.
    """
    if g is None:
        g = guidance_scale
    if steps is None:
        steps = num_inference_steps

    bs = len(prompts)
    text = text_enc(prompts)
    uncond = text_enc([""] * bs, text.shape[1])
    # Unconditional embeddings first, so chunk(2) below splits them back out.
    emb = torch.cat([uncond, text])

    # A private generator keeps seeding local to this call instead of resetting
    # torch's global RNG; on CPU it produces the same noise as
    # torch.manual_seed(seed) followed by torch.randn.
    gen = torch.Generator().manual_seed(seed) if seed is not None else None
    latents = torch.randn((bs, unet.config.in_channels, height // 8, width // 8), generator=gen)
    scheduler.set_timesteps(steps)
    latents = latents * scheduler.init_noise_sigma

    for ts in tqdm(scheduler.timesteps):
        # One UNet pass over [uncond; text] halves instead of two passes.
        inp = scheduler.scale_model_input(torch.cat([latents] * 2), ts)
        noise_uncond, noise_text = unet(inp, ts, encoder_hidden_states=emb).sample.chunk(2)
        pred = noise_uncond + g * (noise_text - noise_uncond)
        latents = scheduler.step(pred, ts, latents).prev_sample

    # Undo the latent scaling (0.18215, Stable Diffusion v1's VAE factor).
    return vae.decode(1 / 0.18215 * latents).sample


def do_both(prompts):
    """Generate an image for a prompt and return it as a PIL image.

    Args:
        prompts: a single prompt string (as passed by the Gradio text box)
            or a non-empty list of prompt strings.

    Returns:
        The first generated image as a ``PIL.Image.Image``.

    Raises:
        ValueError: if ``prompts`` is empty.
    """
    # A bare string is one prompt; wrapping a list again would nest it.
    if isinstance(prompts, str):
        prompts = [prompts]
    prompts = list(prompts)
    if not prompts:
        raise ValueError("prompts must not be empty")
    images = mk_samples(prompts)
    return mk_img(images[0])
            
# Serve do_both as a text-to-image web UI.  Guarded so that importing this
# module (e.g. to reuse text_enc/do_both) does not start a server; running
# the file as a script behaves as before.
if __name__ == "__main__":
    demo = gr.Interface(do_both, gr.Text(), gr.Image(), title = 'Stable Diffusion model from scratch')
    demo.launch(share = True, debug = True)