import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils  # To save image grids
import wandb
from datasets import load_dataset
from PIL import Image
from rich import print as rp
from torch.utils.data import DataLoader, Dataset, Subset
from transformers import DistilBertModel, DistilBertTokenizer, T5EncoderModel, T5Tokenizer

from credits import HUGGINGFACE_TOKEN, HUGGINGFACE_W_TOKEN, WANDB_API_KEY

write_token = HUGGINGFACE_W_TOKEN
read_token = HUGGINGFACE_TOKEN


class TextEncoder(nn.Module):
    def __init__(self, encoder_model_name, encoder_type="t5", device='cpu'):
        super(TextEncoder, self).__init__()
        self.device = device
        self.encoder_type = encoder_type
        if encoder_type == "t5":
            self.tokenizer = T5Tokenizer.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
            self.encoder = T5EncoderModel.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
        elif encoder_type == "distilbert":
            self.tokenizer = DistilBertTokenizer.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
            self.encoder = DistilBertModel.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
        else:
            raise ValueError(f"Invalid encoder_type: {encoder_type}. Choose from 't5' or 'distilbert'.")
        self.encoder.to(self.device)

    def encode_text(self, text):
        if isinstance(text, str):
            text = [text]
        embeddings = []
        for t in text:
            inputs = self.tokenizer(t, return_tensors="pt", padding=True, truncation=True).to(self.device)
            outputs = self.encoder(**inputs)
            embeddings.append(outputs.last_hidden_state[:, 0, :])  # First-token embedding, shape (1, hidden_dim)
        return torch.cat(embeddings, dim=0)  # Combine embeddings into a batch of shape (batch, hidden_dim)


class ConditionalDiffusionModel(nn.Module):
    def __init__(self):
        super(ConditionalDiffusionModel, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(512, 768),  # Input dim 512 matches the t5-small hidden size (DistilBERT would need 768)
            nn.ReLU(),
            nn.Linear(768, 64),
            nn.ReLU(),
            nn.Linear(64, 64)
        )

    def forward(self, text_embeddings):
        return self.model(text_embeddings)


class SuperResolutionDiffusionModel(nn.Module):
    def __init__(self):
        super(SuperResolutionDiffusionModel, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),  # 3 is the number of color channels
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )

    def forward(self, input_image):
        return self.model(input_image)


class TextToImageModel(nn.Module):
    def __init__(self, text_encoder, conditional_diffusion_model, super_resolution_diffusion_model):
        super(TextToImageModel, self).__init__()
        self.text_encoder = text_encoder
        self.conditional_diffusion_model = conditional_diffusion_model
        self.super_resolution_diffusion_model = super_resolution_diffusion_model

    def forward(self, text):
        text_embeddings = self.text_encoder.encode_text(text)
        image_embeddings = self.conditional_diffusion_model(text_embeddings)
        input_image = torch.rand((image_embeddings.shape[0], 3, 128, 128)).to(text_embeddings.device)
        for _ in range(6):  # Refine the image with 6 passes (the conv stack preserves the 128x128 resolution)
            input_image = self.super_resolution_diffusion_model(input_image)
        return input_image


class CustomDataset(Dataset):
    def __init__(self, annotations_file, train_img_dir, size_sqr=128):
        with open(annotations_file, 'r') as f:
            lines = f.readlines()
        self.transform = transforms.Compose([
            transforms.Resize((size_sqr, size_sqr)),
            transforms.ToTensor(),
        ])
        # Each annotation line is "<image_filename> <caption>", split on the first space
        self.img_labels = [line.strip().split(' ', 1) for line in lines]
        self.train_img_dir = train_img_dir

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_name, text = self.img_labels[idx]
        img_path = os.path.join(self.train_img_dir, img_name)
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return text, image
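# A hypothetical annotations.txt layout that CustomDataset can parse (the
# filenames and captions below are illustrative placeholders, not part of the
# original dataset):
#
#   cat_001.jpg A fluffy cat sitting on a windowsill.
#   river_house.jpg A house next to a river at sunset.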
class HuggingDataset(Dataset):  # Wrapper around a HuggingFace Hub dataset
    def __init__(self, dataset_name="vera365/lexica_dataset", size_sqr=128, limit=None):
        dataset = load_dataset(dataset_name, token=read_token, cache_dir='./datasets')["train"]
        self.transform = transforms.Compose([
            transforms.Resize((size_sqr, size_sqr)),
            transforms.ToTensor(),
        ])
        # Apply limit if specified
        self.dataset = Subset(dataset, range(limit)) if limit is not None else dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item["image"].convert("RGB")
        text = item["subject"]
        if self.transform:
            image = self.transform(image)
        return text, image


class StorageHandler:
    def __init__(self, storage_dir="./image_gen_storage", hub_model_name="K00B404/tiny_image_gen",
                 push_dataset=False, dataset_name="K00B404/custom_image_descriptions_dataset"):
        self.model_name = hub_model_name
        self.dataset_name = dataset_name
        self.push_dataset_enabled = push_dataset  # Renamed so the flag does not shadow the push_dataset() method
        self.storage_dir = storage_dir

    def save_checkpoint(self, model, optimizer, scheduler, epoch, checkpoint_path):
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'epoch': epoch
        }
        torch.save(checkpoint, checkpoint_path)

    def load_checkpoint(self, model, optimizer, scheduler, checkpoint_path):
        if os.path.isfile(checkpoint_path):
            checkpoint = torch.load(checkpoint_path)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            epoch = checkpoint['epoch']
            return epoch, scheduler
        else:
            return 0, scheduler

    def push_dataset(self, dataset):
        # Optionally push the dataset to the Hugging Face Hub
        if dataset:
            dataset.push_to_hub(self.dataset_name, token=write_token)

    def push(self, model, tokenizer):
        model.push_to_hub(self.model_name, token=write_token)
        tokenizer.push_to_hub(self.model_name, token=write_token)
class Common:
    def __init__(self, device='cpu', wandb_log=False):
        self.wandb_log = wandb_log
        self.device = device
        self.terminal_log = rp
        if self.wandb_log:
            # Initialize wandb. Assumes you are already logged in;
            # if not, call wandb.login(key=WANDB_API_KEY) first.
            # Note: learning_rate, batch_size, num_epochs and encoder are read from the
            # module-level globals defined in the __main__ block below.
            self.wandb = wandb.init(project="my-image-generation-project", config={
                "learning_rate": learning_rate,
                "batch_size": batch_size,
                "num_epochs": num_epochs,
                "encoder_model": encoder
            })

    def train(self, model, optimizer, scheduler, dataloader, criterion, num_epochs, start_epoch, checkpoint_path):
        for epoch in range(start_epoch, num_epochs):
            model.train()
            for i, (text_batch, image_batch) in enumerate(dataloader):
                image_batch = image_batch.to(self.device)

                optimizer.zero_grad()
                images = model(text_batch)
                loss = criterion(images, image_batch)
                loss.backward()
                optimizer.step()

                if self.wandb_log:
                    # Log loss and learning rate
                    self.wandb.log({"train_loss": loss.item(), "lr": optimizer.param_groups[0]['lr']})

                if i % 25 == 0:
                    # Save a grid of real and generated images for monitoring;
                    # the generated grid uses the model output, not a random noise tensor
                    img_grid_real = vutils.make_grid(image_batch[:4], padding=2, normalize=True)
                    img_grid_fake = vutils.make_grid(images[:4].detach(), padding=2, normalize=True)
                    plt.figure(figsize=(15, 15))
                    plt.subplot(1, 2, 1)
                    plt.axis("off")
                    plt.title("Real Images")
                    plt.imshow(np.transpose(img_grid_real.cpu(), (1, 2, 0)))
                    plt.subplot(1, 2, 2)
                    plt.axis("off")
                    plt.title("Generated Images")
                    plt.imshow(np.transpose(img_grid_fake.cpu(), (1, 2, 0)))
                    plt.savefig(f'generated_images_epoch_{epoch+1}_batch_{i}.png')
                    plt.close()

            # Validation step (note: the training dataloader is reused for validation here)
            val_loss = self.evaluate(model, dataloader, criterion)
            scheduler.step(val_loss)  # Update scheduler with validation loss

            image = self.test_inference(model, "A house next to a river.")
            self.visualize_image(image, f'generated_image_epoch_{epoch + 1}.png')
            self.terminal_log(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Validation Loss: {val_loss}')
            StorageHandler().save_checkpoint(model, optimizer, scheduler, epoch + 1, checkpoint_path)
        self.terminal_log("Training completed.")

    def evaluate(self, model, dataloader, criterion):
        model.eval()
        total_loss = 0
        with torch.no_grad():
            for i, (text_batch, image_batch) in enumerate(dataloader):
                image_batch = image_batch.to(self.device)
                images = model(text_batch)
                loss = criterion(images, image_batch)
                total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        self.terminal_log(f'Validation Loss: {avg_loss}')
        return avg_loss

    def test_inference(self, model, text):
        model.eval()
        with torch.no_grad():
            if isinstance(text, str):
                generated_image = model(text)
            else:
                generated_image = [model(t) for t in text]
        return generated_image

    def visualize_image(self, image_tensor, filename='generated_image.png'):
        image_tensor = image_tensor.squeeze(0).cpu().detach()
        image_tensor = (image_tensor - image_tensor.min()) / (image_tensor.max() - image_tensor.min())  # Normalize to [0, 1]
        image_tensor = image_tensor.permute(1, 2, 0)  # Change from (C, H, W) to (H, W, C)
        plt.imshow(image_tensor)
        plt.savefig(filename)
        # plt.show()


if __name__ == "__main__":
    batch_size = 1
    learning_rate = 1e-4
    num_epochs = 500
    encoder = "google-t5/t5-small"
    checkpoint_path = './models/image_gen'
    os.makedirs(checkpoint_path, exist_ok=True)
    checkpoint_file = f"{checkpoint_path}/checkpoint_backup.pth"
    use_huggingface_dataset = False  # <-- Toggle between datasets
    limit_huggingface_dataset = 1000  # <-- Set the limit for the HuggingFace dataset
    train_img_dir = './train_images'
    annotations_file = f'{train_img_dir}/annotations.txt'
    storage_dir = "./image_gen_storage"
    os.makedirs(storage_dir, exist_ok=True)
    hub_model_name = "K00B404/tiny_image_gen"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    text_encoder = TextEncoder(encoder, encoder_type="t5", device=device)
    conditional_diffusion_model = ConditionalDiffusionModel()
    super_resolution_diffusion_model = SuperResolutionDiffusionModel()
    text_to_image_model = TextToImageModel(text_encoder, conditional_diffusion_model, super_resolution_diffusion_model)
    text_to_image_model.to(device)

    optimizer = optim.AdamW(text_to_image_model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()
    # Note: factor=0.001 multiplies the learning rate by 0.001 on plateau, a very aggressive reduction
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.001, patience=2)

    common = Common(device, True)
    storage = StorageHandler(checkpoint_path, hub_model_name)
    start_epoch, scheduler = storage.load_checkpoint(text_to_image_model, optimizer, scheduler, checkpoint_file)

    if use_huggingface_dataset:
        dataset = HuggingDataset(size_sqr=128, limit=limit_huggingface_dataset)
    else:
        dataset = CustomDataset(annotations_file, train_img_dir, size_sqr=128)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    common.train(text_to_image_model, optimizer, scheduler, dataloader, criterion, num_epochs, start_epoch, checkpoint_file)

    sample_texts = ["A big ape.", "A yellow banana."]
    for idx, sample_text in enumerate(sample_texts):
        generated_image = common.test_inference(text_to_image_model, sample_text)
        # Use a distinct filename per sample so the outputs do not overwrite each other
        common.visualize_image(generated_image, f'sample_generation_{idx}.png')
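    # Optional cleanup: close the wandb run if logging was enabled above.
    # A minimal sketch; wandb.finish() is the standard wandb call for ending a run.
    if common.wandb_log:
        wandb.finish()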