import os
import json
import math
import zipfile

import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
import gradio as gr


class Config:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    image_size = 64
    batch_size = 32
    num_epochs = 50
    learning_rate = 1e-4
    timesteps = 1000
    text_embed_dim = 768
    num_images_options = [1, 4, 6]

    coco_images_url = "http://images.cocodataset.org/zips/train2017.zip"
    coco_annotations_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
    data_dir = "./coco_data"
    images_dir = os.path.join(data_dir, "train2017")
    # The dataset below reads the 'caption' field, which only exists in the
    # captions annotation file (not in instances_train2017.json).
    annotations_path = os.path.join(data_dir, "annotations/captions_train2017.json")

    def __init__(self):
        os.makedirs(self.data_dir, exist_ok=True)


config = Config()


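# Note: COCO train2017 is a large download (the image archive is on the order
# of 18 GB), so the helper below skips any archive that is already on disk and
# only re-extracts when the expected files are missing.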
def download_and_extract_coco():
    if os.path.exists(config.images_dir) and os.path.exists(config.annotations_path):
        print("COCO dataset already downloaded")
        return

    print("Downloading COCO dataset...")

    images_zip_path = os.path.join(config.data_dir, "train2017.zip")
    if not os.path.exists(images_zip_path):
        response = requests.get(config.coco_images_url, stream=True)
        with open(images_zip_path, "wb") as f:
            # Stream in 1 MB chunks; 1 KB chunks make a multi-GB download crawl.
            for chunk in tqdm(response.iter_content(chunk_size=1024 * 1024)):
                if chunk:
                    f.write(chunk)

    annotations_zip_path = os.path.join(config.data_dir, "annotations_trainval2017.zip")
    if not os.path.exists(annotations_zip_path):
        response = requests.get(config.coco_annotations_url, stream=True)
        with open(annotations_zip_path, "wb") as f:
            for chunk in tqdm(response.iter_content(chunk_size=1024 * 1024)):
                if chunk:
                    f.write(chunk)

    # The COCO archives are zip files, so extract them with zipfile
    # (tarfile cannot open zip archives).
    print("Extracting images...")
    with zipfile.ZipFile(images_zip_path, "r") as zf:
        zf.extractall(config.data_dir)

    print("Extracting annotations...")
    with zipfile.ZipFile(annotations_zip_path, "r") as zf:
        zf.extractall(config.data_dir)

    print("COCO dataset ready")


download_and_extract_coco()


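# Frozen BERT text encoder: captions are tokenized and the [CLS] embedding
# (768-d) is used as the conditioning vector for the diffusion model.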
class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = BertModel.from_pretrained('bert-base-uncased')
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, texts):
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=64)
        inputs = {k: v.to(config.device) for k, v in inputs.items()}
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, 0, :]  # [CLS] token embedding


text_encoder = TextEncoder().to(config.device)


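# Small encoder-decoder denoiser. The text embedding is projected, broadcast
# over the bottleneck feature map, and fused with a per-pixel linear layer.
# The timestep t is accepted for a standard DDPM interface but is not embedded
# into the network in this simplified architecture.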
class ConditionalUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.down1 = DownBlock(64, 128)
        self.down2 = DownBlock(128, 256)

        self.text_proj = nn.Linear(config.text_embed_dim, 256)
        self.merge = nn.Linear(256 + 256, 256)

        self.up1 = UpBlock(256, 128)
        self.up2 = UpBlock(128, 64)
        self.final = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    def forward(self, x, t, text_emb):
        # t is part of the DDPM interface; this simplified model does not use it.
        x1 = F.relu(self.conv1(x))
        x2 = self.down1(x1)
        x3 = self.down2(x2)

        # Broadcast the projected caption embedding across the bottleneck.
        text_emb = self.text_proj(text_emb)
        text_emb = text_emb.unsqueeze(-1).unsqueeze(-1)
        text_emb = text_emb.expand(-1, -1, x3.size(2), x3.size(3))

        # Fuse image and text features with a per-pixel linear layer.
        x = torch.cat([x3, text_emb], dim=1)
        b, c, h, w = x.shape
        x = x.permute(0, 2, 3, 1).reshape(b * h * w, c)
        x = self.merge(x)
        x = x.reshape(b, h, w, 256).permute(0, 3, 1, 2)

        x = self.up1(x)
        x = self.up2(x)
        return self.final(x)


class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

    def forward(self, x):
        return self.conv(x)


class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.up(x)
        return self.conv(x)


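# DDPM noise schedule. With a linear beta schedule, alpha_t = 1 - beta_t and
# alpha_bar_t = prod(alpha_1..alpha_t); the forward process can then be sampled
# in closed form as x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.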
def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)


# The schedule must be defined before these module-level tensors are built.
betas = linear_beta_schedule(config.timesteps).to(config.device)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)


def forward_diffusion_sample(x_0, t, device=config.device):
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
    return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise


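# COCO caption dataset: each image is paired with the first of its captions.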
class CocoDataset(Dataset):
    def __init__(self, root_dir, annotations_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        with open(annotations_file, 'r') as f:
            data = json.load(f)

        self.images = []
        self.captions = []

        image_id_to_captions = {}
        for ann in data['annotations']:
            if ann['image_id'] not in image_id_to_captions:
                image_id_to_captions[ann['image_id']] = []
            image_id_to_captions[ann['image_id']].append(ann['caption'])

        for img in data['images']:
            if img['id'] in image_id_to_captions:
                self.images.append(img)
                self.captions.append(image_id_to_captions[img['id']][0])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.images[idx]['file_name'])
        image = Image.open(img_path).convert('RGB')
        caption = self.captions[idx]

        if self.transform:
            image = self.transform(image)

        return image, caption


transform = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # scale to [-1, 1]
])


model = ConditionalUNet().to(config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)


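# Training loop: sample a random timestep per image, noise the image with the
# closed-form forward process, and train the UNet to predict that noise with a
# simple MSE loss (the standard DDPM objective).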
def train():
    dataset = CocoDataset(config.images_dir, config.annotations_path, transform)
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

    model.train()  # generate() switches to eval mode, so reset it here
    for epoch in range(config.num_epochs):
        for batch_idx, (images, captions) in enumerate(tqdm(dataloader)):
            images = images.to(config.device)

            text_emb = text_encoder(captions)

            t = torch.randint(0, config.timesteps, (images.size(0),), device=config.device)

            x_noisy, noise = forward_diffusion_sample(images, t)

            pred_noise = model(x_noisy, t, text_emb)

            loss = F.mse_loss(pred_noise, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

        torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")


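# Reverse (ancestral) DDPM sampling: starting from pure Gaussian noise, each
# step removes the predicted noise via
#   x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta) + sqrt(beta_t) * z,
# where z is fresh noise for t > 0 and zero at the final step.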
@torch.no_grad()
def generate(prompt, num_images=1):
    model.eval()
    num_images = int(num_images)

    text_emb = text_encoder([prompt] * num_images)
    x = torch.randn((num_images, 3, config.image_size, config.image_size)).to(config.device)

    for t in reversed(range(config.timesteps)):
        t_tensor = torch.full((num_images,), t, device=config.device)
        pred_noise = model(x, t_tensor, text_emb)
        alpha_t = alphas[t].view(1, 1, 1, 1)
        alpha_cumprod_t = alphas_cumprod[t].view(1, 1, 1, 1)
        beta_t = betas[t].view(1, 1, 1, 1)

        if t > 0:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)

        x = (1 / torch.sqrt(alpha_t)) * (
            x - ((1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t)) * pred_noise
        ) + torch.sqrt(beta_t) * noise

    # Map from the model's [-1, 1] range back to [0, 1] for display.
    x = torch.clamp(x, -1, 1)
    x = (x + 1) / 2

    images = []
    for img in x:
        images.append(transforms.ToPILImage()(img.cpu()))

    return images


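# To sample from a previously trained run without retraining, the checkpoint
# saved by train() can be loaded first (the epoch number here is just an
# illustrative example, not a file this script guarantees to exist):
#
#     model.load_state_dict(torch.load("model_epoch_49.pth", map_location=config.device))
#     images = generate("a dog playing in the snow", num_images=4)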
def generate_and_display(prompt, num_images):
    images = generate(prompt, num_images)

    fig, axes = plt.subplots(1, len(images), figsize=(5 * len(images), 5))
    if len(images) == 1:
        axes.imshow(images[0])
        axes.axis('off')
    else:
        for ax, img in zip(axes, images):
            ax.imshow(img)
            ax.axis('off')
    plt.tight_layout()
    return fig


with gr.Blocks() as demo:
    gr.Markdown("## GPUDiff-V1: a diffusion-powered image generator")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Enter image description...")
        num_select = gr.Dropdown(choices=config.num_images_options, value=1, label="Number of images")
    generate_btn = gr.Button("Generate")
    output = gr.Plot()

    generate_btn.click(
        fn=generate_and_display,
        inputs=[prompt_input, num_select],
        outputs=output
    )


if __name__ == "__main__":
    train()
    demo.launch()