|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torch.utils.data import Dataset, DataLoader |
|
from torchvision import transforms |
|
from PIL import Image |
|
import json |
|
import os |
|
import subprocess |
|
from transformers import BertTokenizer, BertModel |
|
import wandb |
|
|
|
|
|
# Width of the latent vector z produced by the CVAE's fc_mu / fc_logvar heads.
LATENT_DIM = 128

# Width of the flattened image features and of the text-conditioning vector
# that is concatenated with them.
HIDDEN_DIM = 256
|
|
|
|
|
class Text2ImageDataset(Dataset):
    """Dataset yielding ``(image_tensor, prompt)`` pairs for text-to-image training.

    The metadata file is JSON whose top-level value is indexed by sample
    position; every entry must provide ``file_name`` (joined onto
    ``image_dir``) and ``description``.

    Images are converted to RGBA, turned into tensors and normalised per
    channel with mean 0.5 / std 0.5.
    """

    def __init__(self, image_dir, metadata_file):
        """Load the metadata and build the image transform.

        Args:
            image_dir: Directory that the ``file_name`` entries are relative to.
            metadata_file: Path to the JSON metadata file.
        """
        self.image_dir = image_dir
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(metadata_file, 'r', encoding='utf-8') as f:
            self.metadata = json.load(f)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # Four entries because the images are loaded as 4-channel RGBA.
            transforms.Normalize((0.5, 0.5, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5)),
        ])

    def __len__(self):
        """Return the number of metadata entries."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return ``(image_tensor, prompt)`` for entry ``idx``.

        If the image cannot be opened or decoded, a message is printed and
        ``(None, None)`` is returned instead of raising.  NOTE: the default
        DataLoader collate function cannot batch ``None`` values, so callers
        must filter such samples out (for example with a custom
        ``collate_fn``).
        """
        item = self.metadata[idx]
        image_path = os.path.join(self.image_dir, item['file_name'])

        try:
            # The context manager closes the underlying file as soon as the
            # RGBA copy exists, instead of leaving the handle to the garbage
            # collector.
            with Image.open(image_path) as source:
                image = source.convert('RGBA')
        except FileNotFoundError:
            print(f"Image not found: {image_path}")
            return None, None
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            return None, None

        image = self.transform(image)
        prompt = str(item['description'])
        return image, prompt
|
|
|
|
|
class TextEncoder(nn.Module):
    """Encode tokenised prompts into a fixed-size conditioning vector.

    A pretrained ``bert-base-uncased`` model is followed by a linear layer
    that projects BERT's hidden size down to ``output_size``.
    """

    def __init__(self, hidden_size, output_size):
        """Build the encoder.

        Args:
            hidden_size: Accepted for interface compatibility; not used by
                this implementation (the projection input width is taken
                from the BERT configuration).
            output_size: Width of the vector returned by ``forward``.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(self.bert.config.hidden_size, output_size)

    def forward(self, input_ids, attention_mask):
        """Return the projected hidden state of the first token of each sequence.

        Args:
            input_ids: Token id tensor passed straight to BERT.
            attention_mask: Attention mask tensor passed straight to BERT.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 along the sequence axis summarises the prompt (for BERT
        # tokenizers this is the [CLS] token).
        first_token_state = outputs.last_hidden_state[:, 0, :]
        return self.fc(first_token_state)
|
|
|
|
|
class CVAE(nn.Module):
    """Conditional variational autoencoder for 4-channel images.

    The convolutional encoder ends in ``Linear(128 * 4 * 4, HIDDEN_DIM)``
    after two stride-2 convolutions, so inputs are expected to reduce to a
    4x4 feature map (e.g. 16x16 images).  The conditioning vector ``c`` must
    have ``HIDDEN_DIM`` features; it is concatenated with the image features
    before the latent heads and with ``z`` before decoding.
    """

    def __init__(self, text_encoder):
        """Build the encoder/decoder stacks.

        Args:
            text_encoder: Module stored as ``self.text_encoder``.  It is not
                called inside ``forward``; callers compute ``c`` with it.
        """
        super().__init__()
        self.text_encoder = text_encoder

        # Image encoder: (B, 4, H, W) -> (B, HIDDEN_DIM).
        self.encoder = nn.Sequential(
            nn.Conv2d(4, 32, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, HIDDEN_DIM),
        )

        # Latent heads consume image features concatenated with the condition.
        self.fc_mu = nn.Linear(HIDDEN_DIM + HIDDEN_DIM, LATENT_DIM)
        self.fc_logvar = nn.Linear(HIDDEN_DIM + HIDDEN_DIM, LATENT_DIM)

        # Decoder: (z, c) -> 128x4x4 feature map -> 4-channel image in [-1, 1].
        self.decoder_input = nn.Linear(LATENT_DIM + HIDDEN_DIM, 128 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 4, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def encode(self, x, c):
        """Return ``(mu, logvar)`` of the latent distribution for images ``x`` given ``c``."""
        features = self.encoder(x)
        conditioned = torch.cat([features, c], dim=1)
        return self.fc_mu(conditioned), self.fc_logvar(conditioned)

    def decode(self, z, c):
        """Decode latent ``z`` conditioned on ``c`` into an image batch."""
        conditioned = torch.cat([z, c], dim=1)
        feature_map = self.decoder_input(conditioned).view(-1, 128, 4, 4)
        return self.decoder(feature_map)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)`` so gradients reach ``mu`` and ``logvar``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, c):
        """Return ``(reconstruction, mu, logvar)`` for images ``x`` and condition ``c``."""
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, c), mu, logvar
|
|
|
|
|
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE training loss: summed reconstruction error plus KL term.

    Args:
        recon_x: Reconstructed batch.
        x: Target batch, same shape as ``recon_x``.
        mu: Latent means.
        logvar: Latent log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor ``MSE_sum(recon_x, x) + KL(N(mu, exp(logvar)) || N(0, I))``,
        summed over every element (not averaged over the batch).
    """
    # Reconstruction term is a summed squared error, not a cross-entropy.
    reconstruction = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction + kl_divergence
|
|
|
|
|
def train(model, train_loader, optimizer, device, tokenizer):
    """Run one training epoch and return the average loss per dataset sample.

    Args:
        model: Module exposing ``text_encoder`` and returning
            ``(reconstruction, mu, logvar)`` from ``model(images, text_encoding)``.
        train_loader: Iterable of ``(images, prompts)`` batches that also
            exposes ``.dataset`` with a length.
        optimizer: Optimizer stepped once per batch.
        device: Device the image and token tensors are moved to.
        tokenizer: Callable turning a batch of prompts into a mapping with
            ``input_ids`` and ``attention_mask`` tensors.

    Returns:
        Sum of the per-batch losses divided by ``len(train_loader.dataset)``.
    """
    model.train()
    train_loss = 0
    for data, prompt in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        encoded_input = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
        input_ids = encoded_input['input_ids'].to(device)
        attention_mask = encoded_input['attention_mask'].to(device)

        text_encoding = model.text_encoder(input_ids, attention_mask)

        recon_batch, mu, logvar = model(data, text_encoding)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()

        # Read the scalar once; it is used for both the running total and the log.
        batch_loss = loss.item()
        train_loss += batch_loss

        # These values are only logged, so compute them without autograd tracking.
        with torch.no_grad():
            batch_reconstruction = nn.functional.mse_loss(
                recon_batch, data, reduction='mean'
            ).item()
            batch_kl = (
                -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / data.size(0)
            ).item()

        wandb.log({
            "batch_loss": batch_loss,
            "batch_reconstruction_loss": batch_reconstruction,
            "batch_kl_divergence": batch_kl,
        })

    avg_loss = train_loss / len(train_loader.dataset)
    return avg_loss
|
|
|
|
|
def _collate_valid_samples(batch):
    """Collate a batch after dropping samples whose image failed to load.

    The dataset signals a failed load by returning ``(None, None)``; the
    default collate function cannot batch ``None``, so those samples are
    removed first.

    Raises:
        RuntimeError: If every sample in the batch failed to load.
    """
    from torch.utils.data.dataloader import default_collate

    valid = [sample for sample in batch if sample[0] is not None]
    if not valid:
        raise RuntimeError("Every image in this batch failed to load; check the dataset files.")
    return default_collate(valid)


def main():
    """Train the CVAE, logging metrics, sample images and checkpoints to disk and wandb."""
    # Training hyperparameters.
    NUM_EPOCHS = 500
    BATCH_SIZE = 128
    LEARNING_RATE = 1e-4

    # Checkpoint / sample cadence (in epochs) and output locations.
    SAVE_INTERVAL = 25
    SAVE_INTERVAL_IMAGE = 1
    PROJECT_NAME = "BitRoss"
    MODEL_NAME = "BitRoss"
    SAVE_DIR = "/models/BitRoss/"

    # Number of samples shown in the original-vs-reconstructed preview.
    PREVIEW_COUNT = 4

    os.makedirs(SAVE_DIR, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    DATA_DIR = "./trainingData/"
    METADATA_FILE = "./trainingData/metadata.json"

    wandb.init(project=PROJECT_NAME, config={
        "LATENT_DIM": LATENT_DIM,
        "HIDDEN_DIM": HIDDEN_DIM,
        "NUM_EPOCHS": NUM_EPOCHS,
        "BATCH_SIZE": BATCH_SIZE,
        "LEARNING_RATE": LEARNING_RATE,
        "SAVE_INTERVAL": SAVE_INTERVAL,
        "MODEL_NAME": MODEL_NAME
    })

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = Text2ImageDataset(DATA_DIR, METADATA_FILE)
    train_loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        # Skip samples the dataset could not load instead of crashing in collate.
        collate_fn=_collate_valid_samples,
    )

    text_encoder = TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)
    model = CVAE(text_encoder).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    wandb.watch(model, log="all", log_freq=100)

    # Imported once, before training starts, so a missing module fails fast
    # rather than after the first epoch has already run.
    from generate import generate_image

    to_pil = transforms.ToPILImage()

    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss = train(model, train_loader, optimizer, device, tokenizer)
        print(f'Epoch {epoch}, Loss: {train_loss:.4f}')

        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss,
        })

        if epoch % SAVE_INTERVAL_IMAGE == 0:
            output_image = f"{SAVE_DIR}output_epoch_{epoch}.png"

            prompt = "A blue sword made of diamond"
            generated_image = generate_image(model, prompt, device)
            generated_image.save(output_image)

            wandb.log({
                "generated_image": wandb.Image(output_image, caption=f"Generated at epoch {epoch} with prompt {prompt}")
            })

        if epoch % SAVE_INTERVAL == 0:
            model_save_path = f"{SAVE_DIR}{MODEL_NAME}_epoch_{epoch}.pth"
            torch.save(model.state_dict(), model_save_path)
            print(f"Model saved to {model_save_path}")

        if epoch % 10 == 0:
            # train() switches the model back to training mode next epoch.
            model.eval()
            with torch.no_grad():
                sample_data, sample_prompt = next(iter(train_loader))
                sample_data = sample_data[:PREVIEW_COUNT].to(device)
                encoded_input = tokenizer(sample_prompt[:PREVIEW_COUNT], padding=True, truncation=True, return_tensors="pt")
                input_ids = encoded_input['input_ids'].to(device)
                attention_mask = encoded_input['attention_mask'].to(device)
                text_encoding = model.text_encoder(input_ids, attention_mask)
                recon_batch, _, _ = model(sample_data, text_encoding)

                # A batch may hold fewer than PREVIEW_COUNT samples.
                preview_count = sample_data.size(0)

                # Undo the mean-0.5 / std-0.5 normalisation before converting.
                original_images = [to_pil((sample_data[i] * 0.5 + 0.5).cpu()) for i in range(preview_count)]
                reconstructed_images = [to_pil((recon_batch[i] * 0.5 + 0.5).cpu()) for i in range(preview_count)]

                wandb.log({
                    f"original_vs_reconstructed_{i}": [wandb.Image(original_images[i], caption=f"Original {i}"),
                                                       wandb.Image(reconstructed_images[i], caption=f"Reconstructed {i}")]
                    for i in range(preview_count)
                })

    wandb.finish()
|
|
|
# Start training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()