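# NOTE(review): the next three lines look like web-page residue captured with
# the file (avatar alt text, commit message, short commit hash); they are kept
# as comments so the module parses.
# miccull's picture
# no more cuda
# 9ca1ed9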
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torchvision
import clip
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
# Use the GPU when torch reports one is available, otherwise the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Name of the CLIP checkpoint to load.
model_name = 'ViT-B/16' #@param ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
model, preprocess = clip.load(model_name)

# Move the model to the chosen device and switch it to evaluation mode.
model.to(DEVICE)
model.eval()

# Square resize to the input resolution reported by the visual encoder.
resolution = model.visual.input_resolution
resizer = torchvision.transforms.Resize((resolution, resolution))
def create_rgb_tensor(color):
    """Return ``color`` (three values, e.g. [1, 0, 0]) as a (1, 3, 1, 1) tensor on DEVICE.

    The dtype is whatever ``torch.tensor`` infers from ``color``.
    """
    flat = torch.tensor(color, device=DEVICE)
    return flat.reshape(1, 3, 1, 1)
def encode_color(color):
    """Return ``model.encode_image`` of a solid color such as [1, 0, 0].

    The color becomes a single pixel, which is resized with ``resizer``
    before being passed to the model.
    """
    pixel = create_rgb_tensor(color)
    upscaled = resizer(pixel)
    return model.encode_image(upscaled)
def encode_text(text):
    """Tokenize ``text`` with ``clip.tokenize`` and return ``model.encode_text`` of the tokens."""
    tokens = clip.tokenize(text)
    tokens = tokens.to(DEVICE)
    return model.encode_text(tokens)
class RGBModel(torch.nn.Module):
    """Module whose only parameter is one RGB color of shape (1, 3, 1, 1)."""

    def __init__(self, device):
        """Create the color parameter on ``device`` with every channel at 0.5."""
        super().__init__()
        initial_color = torch.full((1, 3, 1, 1), 0.5, device=device)
        self.color = torch.nn.Parameter(initial_color)

    def forward(self):
        """Clamp the color to the closed interval [0, 1] and return the parameter."""
        # clamp() yields a new tensor, so this rebinds .data to fresh storage
        # instead of mutating the existing storage in place.
        bounded = torch.clamp(self.color.data, min=0, max=1)
        self.color.data = bounded
        return self.color
# Gradio input widgets for the demo UI.
text_input = gr.inputs.Textbox(
    label="Text Prompt",
    default='A solid red square',
    lines=1,
)
steps_input = gr.inputs.Slider(
    label="Training Steps",
    default=11,
    minimum=1,
    maximum=100,
    step=1,
)
lr_input = gr.inputs.Number(
    label="Adam Optimizer Learning Rate",
    default=0.06,
)
decay_input = gr.inputs.Number(
    label="Adam Optimizer Weight Decay",
    default=0.01,
)
def gradio_fn(text_prompt, adam_learning_rate, adam_weight_decay, n_iterations=50):
    """Optimize a single RGB color so its image embedding matches a text prompt.

    A fresh ``RGBModel`` is trained with AdamW to maximize the cosine
    similarity between ``model.encode_image`` of the resized color and
    ``model.encode_text`` of ``text_prompt``.

    Args:
        text_prompt: Text passed to ``clip.tokenize`` to build the target.
        adam_learning_rate: ``lr`` for ``torch.optim.AdamW``.
        adam_weight_decay: ``weight_decay`` for ``torch.optim.AdamW``.
        n_iterations: Number of optimizer steps. Coerced with ``int()`` so a
            whole number delivered as a float (e.g. ``11.0``) is accepted.

    Returns:
        A 400x100 PIL image built from a 1-pixel-tall strip holding the
        initial color followed by the color after each step, left to right,
        enlarged with nearest-neighbour resampling.
    """
    # range() rejects floats; UI number widgets may hand over e.g. 11.0.
    n_iterations = int(n_iterations)

    rgb_model = RGBModel(device=DEVICE)
    # Give the optimizer the parameter directly; no forward pass is needed
    # just to obtain it.
    opt = torch.optim.AdamW(
        rgb_model.parameters(),
        lr=adam_learning_rate,
        weight_decay=adam_weight_decay,
    )

    # The text target never changes, so compute it once without autograd.
    with torch.no_grad():
        tokenized_text = clip.tokenize(text_prompt).to(DEVICE)
        target_embedding = model.encode_text(tokenized_text)

    def snapshot():
        # Calling the model clamps the color to [0, 1] before it is recorded.
        # Copy explicitly: on CPU, .numpy() shares memory with the tensor and
        # the optimizer updates the parameter in place.
        return rgb_model().detach().cpu().numpy().copy()

    def training_step():
        opt.zero_grad()
        color_img = resizer(rgb_model())
        image_embedding = model.encode_image(color_img)
        # Maximize cosine similarity by minimizing its negative.
        loss = -1 * torch.cosine_similarity(target_embedding, image_embedding, dim=-1)
        # NOTE(review): backward() also fills .grad on any weights of `model`
        # that require grad; only the color is registered with the optimizer,
        # so those gradients are never applied.
        loss.backward()
        opt.step()

    snapshots = [snapshot()]
    for _ in range(n_iterations):
        training_step()
        snapshots.append(snapshot())

    # Each snapshot is (1, 3, 1, 1); stack along axis 0 and drop the two
    # trailing singleton axes to get one RGB triple per row.
    palette = np.concatenate(snapshots, axis=0)[:, :, 0, 0]
    # Add a leading height axis of 1 so PIL sees a (1, n_colors, 3) RGB strip.
    strip = (palette * 255).astype(np.uint8)[np.newaxis]
    img_train = Image.fromarray(strip).resize((400, 100), Image.NEAREST)
    return img_train
# Build the demo UI and start serving it. The inputs are listed in the same
# order as gradio_fn's parameters: prompt, learning rate, weight decay, steps.
iface = gr.Interface(
    fn=gradio_fn,
    inputs=[text_input, lr_input, decay_input, steps_input],
    outputs="image",
)
iface.launch()