# Page header captured along with this file (author: bryandts; commit 35de619,
# "Update app.py", verified; size 4.82 kB).
import gradio as gr
import torch
import torchvision.transforms as transforms
from sentence_transformers import SentenceTransformer, util
import json
import os
import matplotlib.pyplot as plt
import random
import torch
import torch.nn as nn
# The Generator model
class Generator(nn.Module):
def __init__(self, channels, noise_dim=100, embed_dim=1024, embed_out_dim=128):
super(Generator, self).__init__()
self.channels = channels
self.noise_dim = noise_dim
self.embed_dim = embed_dim
self.embed_out_dim = embed_out_dim
# Text embedding layers
self.text_embedding = nn.Sequential(
nn.Linear(self.embed_dim, self.embed_out_dim),
nn.BatchNorm1d(1),
nn.LeakyReLU(0.2, inplace=True)
)
# Generator architecture
model = []
model += self._create_layer(self.noise_dim + self.embed_out_dim, 512, 4, stride=1, padding=0)
model += self._create_layer(512, 256, 4, stride=2, padding=1)
model += self._create_layer(256, 128, 4, stride=2, padding=1)
model += self._create_layer(128, 64, 4, stride=2, padding=1)
model += self._create_layer(64, 32, 4, stride=2, padding=1)
model += self._create_layer(32, self.channels, 4, stride=2, padding=1, output=True)
self.model = nn.Sequential(*model)
def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, output=False):
layers = [nn.ConvTranspose2d(size_in, size_out, kernel_size, stride=stride, padding=padding, bias=False)]
if output:
layers.append(nn.Tanh()) # Tanh activation for the output layer
else:
layers += [nn.BatchNorm2d(size_out), nn.ReLU(True)] # Batch normalization and ReLU for other layers
return layers
def forward(self, noise, text):
# Apply text embedding to the input text
text = self.text_embedding(text)
text = text.view(text.shape[0], text.shape[2], 1, 1) # Reshape to match the generator input size
z = torch.cat([text, noise], 1) # Concatenate text embedding with noise
return self.model(z)
# Generator hyper-parameters. They fix the layer shapes, so they must match the
# checkpoint loaded below, or load_state_dict() raises a size-mismatch error.
# embed_dim must also equal the sentence encoder's output size (presumably 384
# for all-MiniLM-L12-v2 -- confirm if the encoder is swapped).
noise_dim = 16
embed_dim = 384
embed_out_dim = 256
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator(channels=3, embed_dim=embed_dim, noise_dim=noise_dim, embed_out_dim=embed_out_dim).to(device)
# Path to your .pth file
gen_weight = 'generator_20240421_3.pth'
# Load the weights. `device` is already a torch.device, so it goes to
# map_location directly. weights_only=True restricts unpickling to tensors and
# plain containers, which is all a state_dict needs, and prevents the file from
# running arbitrary pickled code.
weights_gen = torch.load(gen_weight, map_location=device, weights_only=True)
# Apply the weights to your model
generator.load_state_dict(weights_gen)
# NOTE(review): generator is never switched to eval(). Its BatchNorm layers
# therefore normalise with per-call batch statistics and update the loaded
# running statistics on every call. This is left unchanged because switching
# modes changes the generated images -- confirm which behaviour is intended.
# Sentence encoder, used both for the description embeddings and for captions.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
def load_embedding(model, path="descriptions.json"):
    """Encode every description listed in a JSON file.

    Args:
        model: sentence encoder exposing ``encode(text, convert_to_tensor=True)``
            (the app passes its SentenceTransformer).
        path: JSON file holding a list of objects, each with a ``"text"`` key.
            Defaults to ``descriptions.json`` in the working directory.

    Returns:
        dict mapping each description text to its embedding, in order of first
        appearance. A duplicate text keeps a single entry.
    """
    # JSON is UTF-8 by specification. Without an explicit encoding, open()
    # would use the platform locale.
    with open(path, 'r', encoding='utf-8') as file:
        dataset = json.load(file)
    classes = [entry["text"] for entry in dataset]
    return {cls: model.encode(cls, convert_to_tensor=True) for cls in classes}
# Cache for the description embeddings. Encoding every entry of
# descriptions.json is the expensive part, so it runs on the first request
# only. Edits to the JSON file take effect after an app restart.
_embedding_cache = {}


def _cached_embeddings():
    """Return the description -> embedding dict, building it on first use."""
    if "descriptions" not in _embedding_cache:
        _embedding_cache["descriptions"] = load_embedding(model)
    return _embedding_cache["descriptions"]


def _rank_descriptions(caption, embeddings, top_k=5):
    """Return the ``top_k`` (cosine similarity, description) pairs, best first."""
    # Encode the caption once rather than once per description.
    query = model.encode(caption, convert_to_tensor=True)
    scored = [(util.pytorch_cos_sim(query, emb).item(), cls) for cls, emb in embeddings.items()]
    return sorted(scored, key=lambda pair: pair[0], reverse=True)[:top_k]


def generate_image(caption):
    """Generate an image for ``caption`` from the closest known description(s).

    Steps:
    1. Match the caption by cosine similarity against every description in
       descriptions.json.
    2. If even the best match scores <= 0.40, replace the caption with that
       best-matching description and repeat the matching.
    3. Starting from the best of the top 5 matches, keep descriptions while
       each score is at least ``coeff`` times the previous kept score.
    4. Sample one of the kept descriptions with probability proportional to
       its score, and use its embedding to condition the generator.

    Returns:
        (H, W, C) float array min-max scaled to [0, 1]. A constant image
        becomes all zeros.

    Raises:
        ValueError: if descriptions.json lists no descriptions.
    """
    embeddings = _cached_embeddings()
    if not embeddings:
        raise ValueError("descriptions.json contains no descriptions")

    threshold = 0.40
    coeff = 0.89
    sorted_results = _rank_descriptions(caption, embeddings)
    if sorted_results[0][0] <= threshold:
        # The caption is too far from every known description: fall back to
        # the closest description and rank against that instead.
        caption = sorted_results[0][1]
        sorted_results = _rank_descriptions(caption, embeddings)
    # NOTE(review): the original file's indentation was lost. This check is
    # reconstructed at function level, so it applies to direct matches as well
    # as to the fallback path (where the top score is a self-match) -- confirm.
    if sorted_results[0][0] >= 0.99:
        coeff = 0.85

    # Keep the leading run of candidates whose score stays within `coeff` of
    # the previously kept one.
    last_score = sorted_results[0][0]
    filtered_results = []
    for score, cls in sorted_results:
        if score < last_score * coeff:
            break
        filtered_results.append((score, cls))
        last_score = score

    items = [cls for _, cls in filtered_results]
    probabilities = [score for score, _ in filtered_results]
    sampled_item = random.choices(items, weights=probabilities, k=1)[0]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        noise = torch.randn(1, noise_dim, 1, 1, device=device)
        # Shape the 1-D embedding as (1, 1, embed_dim), the layout the generator's text branch takes.
        text = embeddings[sampled_item].to(device).unsqueeze(0).unsqueeze(0)
        fake_images = generator(noise, text)
    img = fake_images.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # Min-max scale to [0, 1], guarding against a zero range.
    img = img - img.min()
    peak = img.max()
    if peak > 0:
        img = img / peak
    return img
# Gradio UI: a caption textbox in, one generated image (numpy array) out.
caption_box = gr.Textbox(lines=2, placeholder="Enter Caption Here...")
image_out = gr.Image(type="numpy")
iface = gr.Interface(
    fn=generate_image,
    inputs=caption_box,
    outputs=image_out,
    title="Text-to-Image Generation",
    description="Enter a caption to generate an image.",
)
iface.launch()