# Hugging Face Spaces page header captured with this file (Space status: Sleeping).
import gradio as gr | |
from typing import Dict, Tuple | |
from tqdm import tqdm | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader | |
from torchvision import models, transforms | |
from torchvision.utils import save_image, make_grid | |
import matplotlib.pyplot as plt | |
from matplotlib.animation import FuncAnimation, PillowWriter | |
import numpy as np | |
from IPython.display import HTML | |
from diffusion_utilities import * | |
from PIL import Image as im | |
#openai.api_key = os.getenv('OPENAI_API_KEY') | |
class ContextUnet(nn.Module):
    """Context-conditioned U-Net that predicts the noise present in an image.

    The encoder applies an initial residual conv block and two ``UnetDown``
    stages, then pools the bottleneck to a 1x1 vector. The decoder expands it
    back with a transposed convolution (kernel/stride ``height // 4``) and two
    ``UnetUp`` stages whose inputs are modulated by the context and timestep
    embeddings as ``context_emb * features + time_emb``. A final conv head maps
    back to ``in_channels`` channels.

    Args:
        in_channels: number of channels in the input image.
        n_feat: base number of intermediate feature maps.
        n_cfeat: length of the context vector ("cfeat" = context features).
        height: side length of the (square) input image; must be a positive
            multiple of 4 because the bottleneck is ``height // 4`` wide.

    Raises:
        ValueError: if ``height`` is not a positive multiple of 4.
    """

    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):
        super().__init__()
        if height <= 0 or height % 4:
            raise ValueError(f"height must be a positive multiple of 4, got {height}")
        # number of input channels, number of intermediate feature maps and context length
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height  # assume h == w

        # Initial convolutional layer.
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        # Down-sampling path of the U-Net with two levels.
        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)

        # Collapse the bottleneck to 1x1. The kernel must match the bottleneck
        # width (h // 4, the same factor up0 uses to expand it back); the old
        # fixed kernel of 4 was only right for height == 16.
        self.to_vec = nn.Sequential(nn.AvgPool2d(self.h // 4), nn.GELU())

        # Embed the timestep and context labels with a one-layer fully connected network.
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, 1 * n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 1 * n_feat)

        # Up-sampling path of the U-Net with three levels.
        self.up0 = nn.Sequential(
            # expand the 1x1 vector back to an (h//4 x h//4) map
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4),
            nn.GroupNorm(8, 2 * n_feat),  # normalize
            nn.ReLU(),
        )
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        # Final convolutional layers mapping back to the input's channel count.
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),  # reduce number of feature maps
            nn.GroupNorm(8, n_feat),  # normalize
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),  # map to same number of channels as input
        )

    def forward(self, x, t, c=None):
        """Predict the noise in ``x``.

        Args:
            x: (batch, in_channels, h, w) input image.
            t: normalised timestep tensor fed to ``EmbedFC(1, ...)``; the
               samplers in this file pass shape (1, 1, 1, 1).
            c: (batch, n_cfeat) context vector, or ``None`` to use an
               all-zero context.

        Returns:
            Tensor with ``in_channels`` channels: the predicted noise.
        """
        # pass the input image through the initial convolutional layer
        x = self.init_conv(x)
        # down-sampling path
        down1 = self.down1(x)
        down2 = self.down2(down1)
        # convert the feature maps to a vector and apply an activation
        hiddenvec = self.to_vec(down2)
        # no context supplied: fall back to an all-zero context vector
        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)
        # embed context and timestep
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)  # (batch, 2*n_feat, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
        up1 = self.up0(hiddenvec)
        # modulate decoder features: multiply by context embedding, add time embedding
        up2 = self.up1(cemb1 * up1 + temb1, down2)
        up3 = self.up2(cemb2 * up2 + temb2, down1)
        # skip connection from the initial conv output into the output head
        out = self.out(torch.cat((up3, x), 1))
        return out
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------

# Diffusion: linear beta schedule from beta1 to beta2 over `timesteps` steps.
timesteps = 5000
beta1 = 1e-4
beta2 = 0.02

# Network
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
n_feat = 64  # hidden feature width of the U-Net
n_cfeat = 5  # length of the context vector
height = 16  # images are height x height (16x16)
save_dir = './weights/'

# Training (not referenced by the inference code in this file)
batch_size = 1000
n_epoch = 512
lrate = 1e-3

# DDPM noise schedule, indexed 0..timesteps:
#   b_t  - per-step noise variance, linear between beta1 and beta2
#   a_t  - 1 - b_t
#   ab_t - running product of a_t, computed as exp(cumsum(log a_t)); entry 0 forced to 1
b_t = beta1 + (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device)
a_t = 1 - b_t
ab_t = a_t.log().cumsum(dim=0).exp()
ab_t[0] = 1
# Build the U-Net on the selected device. Its hyperparameters must agree with
# the checkpoints loaded below, since load_state_dict is strict by default.
nn_model = ContextUnet(
    in_channels=3,
    n_feat=n_feat,
    n_cfeat=n_cfeat,
    height=height,
).to(device)
# DDIM update: removes the predicted noise deterministically (eta = 0).
def denoise_ddim(x, t, t_prev, pred_noise):
    """Take one deterministic DDIM step from schedule index ``t`` to ``t_prev``.

    Args:
        x: current noisy sample at schedule index ``t``.
        t: schedule index of ``x`` (into the module-level ``ab_t``).
        t_prev: target schedule index. Values below 0 are clamped to 0 so a
            final stride that overshoots the start of the schedule lands on
            index 0 instead of wrapping around through negative indexing.
        pred_noise: the model's noise prediction for ``x`` at ``t``.

    Returns:
        The sample moved to schedule index ``t_prev`` (no fresh noise added).
    """
    t_prev = max(t_prev, 0)
    ab = ab_t[t]
    ab_prev = ab_t[t_prev]
    # sqrt(ab_prev) * x0_hat, where x0_hat = (x - sqrt(1 - ab) * eps) / sqrt(ab)
    scaled_x0 = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
    # direction pointing back towards the noise level of t_prev
    dir_xt = (1 - ab_prev).sqrt() * pred_noise
    return scaled_x0 + dir_xt
# Load the unconditional checkpoint and switch the model to inference mode.
# NOTE(review): these weights are replaced when the context checkpoint is
# loaded further down in this file, before the interface is launched.
nn_model.load_state_dict(
    state_dict=torch.load(f"{save_dir}/model_31.pth", map_location=device),
)
nn_model.eval()
print("Loaded in Model without context")
# sample quickly using DDIM
@torch.no_grad()  # inference only: don't chain autograd graphs across steps
def sample_ddim(n_sample, n=20):
    """Generate unconditional samples with DDIM in about ``n`` steps.

    Args:
        n_sample: number of images to generate.
        n: requested number of denoising steps (>= 1); the stride through the
            schedule is ``timesteps // n`` (at least 1).

    Returns:
        (samples, intermediate): the final batch of shape
        (n_sample, 3, height, height) and a numpy array stacking the batch
        after every step along a new leading axis.

    Raises:
        ValueError: if ``n`` < 1.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)
    # array to keep track of generated steps for plotting
    intermediate = []
    # n > timesteps would otherwise give a zero stride, which range() rejects
    step_size = max(timesteps // n, 1)
    for i in range(timesteps, 0, -step_size):
        print(f'sampling timestep {i:3d}', end='\r')
        # reshape time tensor
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)
        eps = nn_model(samples, t)  # predict noise e_(x_t,t)
        # clamp so a stride that doesn't divide `timesteps` ends at index 0
        samples = denoise_ddim(samples, i, max(i - step_size, 0), eps)
        intermediate.append(samples.detach().cpu().numpy())
    intermediate = np.stack(intermediate)
    return samples, intermediate
# Load the context-conditioned (fine-tuned) checkpoint over the current
# weights and switch the model to inference mode; this is the model the
# interface below uses.
nn_model.load_state_dict(
    state_dict=torch.load(f"{save_dir}/ft_context_model_0.pth", map_location=device),
)
nn_model.eval()
print("Loaded in Context Model")
# fast sampling algorithm with context
@torch.no_grad()  # inference only: don't chain autograd graphs across steps
def sample_ddim_context(n_sample, context, n=20):
    """Generate context-conditioned samples with DDIM in about ``n`` steps.

    Args:
        n_sample: number of images to generate.
        context: context tensor passed to the model as ``c`` (``None`` lets
            the model use an all-zero context).
        n: requested number of denoising steps (>= 1); the stride through the
            schedule is ``timesteps // n`` (at least 1).

    Returns:
        (samples, intermediate): the final batch of shape
        (n_sample, 3, height, height) and a numpy array stacking the batch
        after every step along a new leading axis.

    Raises:
        ValueError: if ``n`` < 1.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)
    # array to keep track of generated steps for plotting
    intermediate = []
    # n > timesteps would otherwise give a zero stride, which range() rejects
    step_size = max(timesteps // n, 1)
    for i in range(timesteps, 0, -step_size):
        print(f'sampling timestep {i:3d}', end='\r')
        # reshape time tensor
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)
        eps = nn_model(samples, t, c=context)  # predict noise e_(x_t,t)
        # clamp so a stride that doesn't divide `timesteps` ends at index 0
        samples = denoise_ddim(samples, i, max(i - step_size, 0), eps)
        intermediate.append(samples.detach().cpu().numpy())
    intermediate = np.stack(intermediate)
    return samples, intermediate
# Ancestral DDPM step: strip the predicted noise, then re-inject a little
# fresh noise so sampling doesn't collapse.
def denoise_add_noise(x, t, pred_noise, z=None):
    """Move ``x`` from schedule index ``t`` to ``t - 1``.

    Args:
        x: current noisy sample.
        t: integer index into the module-level ``a_t``/``ab_t``/``b_t``.
        pred_noise: the model's noise prediction for ``x`` at ``t``.
        z: noise to re-inject, scaled by sqrt(b_t[t]); a standard normal
            sample shaped like ``x`` is drawn when omitted. Passing 0 adds none.

    Returns:
        The denoised mean plus the scaled noise.
    """
    if z is None:
        z = torch.randn_like(x)
    alpha = a_t[t]
    # weight on the predicted noise: (1 - alpha_t) / sqrt(1 - alphabar_t)
    eps_weight = (1 - alpha) / (1 - ab_t[t]).sqrt()
    mean = (x - pred_noise * eps_weight) / alpha.sqrt()
    return mean + b_t[t].sqrt() * z
# sample using standard algorithm
@torch.no_grad()  # inference only: don't chain autograd graphs across steps
def sample_ddpm(n_sample, context=None, save_rate=20):
    """Generate samples with full-length ancestral (DDPM) sampling.

    Args:
        n_sample: number of images to generate.
        context: optional context tensor passed to the model as ``c``.
            Previously this argument was accepted but ignored; ``None``
            (the new default) reproduces the old unconditional behaviour,
            because the model substitutes an all-zero context.
        save_rate: keep a snapshot every ``save_rate`` steps, plus the first
            step and the last seven.

    Returns:
        (samples, intermediate): the final batch of shape
        (n_sample, 3, height, height) and a numpy array stacking the saved
        snapshots along a new leading axis.
    """
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)
    # array to keep track of generated steps for plotting
    intermediate = []
    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')
        # reshape time tensor
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)
        # sample some random noise to inject back in. For i = 1, don't add back in noise
        z = torch.randn_like(samples) if i > 1 else 0
        eps = nn_model(samples, t, c=context)  # predict noise e_(x_t,t)
        samples = denoise_add_noise(samples, i, eps, z)
        if i % save_rate == 0 or i == timesteps or i < 8:
            intermediate.append(samples.detach().cpu().numpy())
    intermediate = np.stack(intermediate)
    return samples, intermediate
@torch.no_grad()  # inference only: don't chain autograd graphs across steps
def sample_ddpm_context(n_sample, timesteps, context, save_rate=20):
    """Generate context-conditioned samples with ancestral (DDPM) sampling.

    ``timesteps`` here is the number of sampling steps the caller wants. It
    shadows the module-level schedule length, but the noise schedule
    (``ab_t``) always spans the full training schedule. Previously the loop
    walked indices ``timesteps..1`` of that long schedule (nearly noise-free
    coefficients) while telling the model ``t = i / timesteps`` (pure noise),
    so short runs returned essentially unchanged noise. The schedule is now
    evenly respaced: each step jumps from ``cur`` to ``prev`` and uses the
    effective beta ``1 - ab_t[cur] / ab_t[prev]``. When ``timesteps`` equals
    the schedule length every jump is a single step, i.e. standard DDPM.

    Args:
        n_sample: number of images to generate.
        timesteps: number of denoising steps (>= 1). Values above the
            schedule length are clamped to it.
        context: context tensor passed to the model as ``c`` (``None`` lets
            the model use an all-zero context).
        save_rate: keep a snapshot every ``save_rate`` steps, plus the first
            step and the last seven.

    Returns:
        (samples, intermediate): the final batch of shape
        (n_sample, 3, height, height) and a numpy array stacking the saved
        snapshots along a new leading axis.

    Raises:
        ValueError: if ``timesteps`` < 1.
    """
    total = ab_t.shape[0] - 1  # length of the training noise schedule
    n_steps = min(int(timesteps), total)
    if n_steps < 1:
        raise ValueError(f"timesteps must be >= 1, got {timesteps}")
    # Evenly spaced schedule indices: taus[0] = 0 < taus[1] < ... < taus[n_steps] = total
    # (integer floor division keeps them strictly increasing while n_steps <= total).
    taus = [k * total // n_steps for k in range(n_steps + 1)]

    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)
    # array to keep track of generated steps for plotting
    intermediate = []
    for k in range(n_steps, 0, -1):
        cur, prev = taus[k], taus[k - 1]
        print(f'sampling timestep {cur:3d}', end='\r')
        # time normalised by the schedule length, as in the other samplers here
        t = torch.tensor([cur / total])[:, None, None, None].to(device)
        # fresh noise to inject back in; none on the final step
        z = torch.randn_like(samples) if k > 1 else 0
        eps = nn_model(samples, t, c=context)  # predict noise e_(x_t,t)
        ab, ab_prev = ab_t[cur], ab_t[prev]
        alpha = ab / ab_prev  # effective alpha for the jump cur -> prev
        beta = 1 - alpha
        mean = (samples - eps * (beta / (1 - ab).sqrt())) / alpha.sqrt()
        samples = mean + beta.sqrt() * z
        if k % save_rate == 0 or k == n_steps or k < 8:
            intermediate.append(samples.detach().cpu().numpy())
    intermediate = np.stack(intermediate)
    return samples, intermediate
def greet(input):
    """Gradio callback: sample one context-conditioned sprite as a PIL image.

    Args:
        input: text from the "steps" textbox, parsed with ``int`` as the
            number of DDPM sampling steps.

    Returns:
        PIL image of the final denoised sample.

    Raises:
        ValueError: if ``input`` is not an integer string.
    """
    steps = int(input)
    image_count = 1
    # One-hot context; class order: hero, non-hero, food, spell, side-facing.
    # Every generated image requests class 0 ("hero").
    ctx = torch.zeros(image_count, n_cfeat, device=device)
    ctx[:, 0] = 1
    _, intermediate = sample_ddpm_context(image_count, steps, ctx)
    # intermediate is (snapshots, batch, C, H, W); move channels last
    # (presumably the layout norm_all / transform expect — confirm in diffusion_utilities).
    sx_gen_store = np.moveaxis(intermediate, 2, 4)
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], image_count)
    # last saved snapshot, first image in the batch
    return transform2(transform(nsx_gen_store[-1][0]))
# torchvision converter from a tensor / ndarray to a PIL image; greet() uses
# it to produce the image shown in the interface.
transform2 = transforms.ToPILImage()

# Web UI: one "steps" textbox in, one generated 64px-wide image out.
iface = gr.Interface(
    fn=greet,
    inputs=[
        gr.Textbox(label="steps", value=20),
    ],
    outputs=[
        gr.Image(type="pil", width=64, label="Output Image"),
    ],
)
iface.launch()