import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from transformers import T5ForConditionalGeneration, T5Tokenizer
import matplotlib.pyplot as plt

device = "cuda" if torch.cuda.is_available() else "cpu"


class TextEncoder(nn.Module):
    def __init__(self, encoder_model_name):
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(encoder_model_name)
        self.encoder = T5ForConditionalGeneration.from_pretrained(encoder_model_name)
        self.encoder.to(device)

    def encode_text(self, text):
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        inputs = {key: value.to(device) for key, value in inputs.items()}
        # Run only the T5 encoder stack; the decoder is never needed here.
        outputs = self.encoder.encoder(**inputs)
        # T5 has no CLS token; the first token's hidden state is used as a
        # cheap sentence summary (mean pooling would be a common alternative).
        embeddings = outputs.last_hidden_state[:, 0, :]
        return embeddings


class ConditionalDiffusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(512, 768),  # t5-small embeddings are 512-dimensional
            nn.ReLU(),
            nn.Linear(768, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
        )

    def forward(self, text_embeddings):
        return self.model(text_embeddings)


class SuperResolutionDiffusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),  # 3 input color channels
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
        )

    def forward(self, input_image):
        return self.model(input_image)


class TextToImageModel(nn.Module):
    def __init__(self, text_encoder, conditional_diffusion_model, super_resolution_diffusion_model):
        super().__init__()
        self.text_encoder = text_encoder
        self.conditional_diffusion_model = conditional_diffusion_model
        self.super_resolution_diffusion_model = super_resolution_diffusion_model

    def forward(self, text):
        text_embeddings = self.text_encoder.encode_text(text)                 # (B, 512)
        image_embeddings = self.conditional_diffusion_model(text_embeddings)  # (B, 64)
        batch_size = image_embeddings.size(0)
        # Seed the image from the text conditioning: reshape the 64-dim
        # embedding to an 8x8 map, upsample it to 64x64, and add noise.
        # (Previously the image was initialized with batch-size-1 random
        # values and image_embeddings went unused, so no gradient ever
        # reached the text pathway and batch sizes > 1 broke the loss.)
        conditioning = image_embeddings.view(batch_size, 1, 8, 8)
        conditioning = F.interpolate(conditioning, size=(64, 64), mode="bilinear", align_corners=False)
        input_image = conditioning.repeat(1, 3, 1, 1) + torch.randn(
            batch_size, 3, 64, 64, device=conditioning.device
        )
        # Refine the image with 6 passes; the convolutions keep the spatial
        # size at 64x64, so this is iterative refinement, not upsampling.
        for _ in range(6):
            input_image = self.super_resolution_diffusion_model(input_image)
        return input_image


class CustomDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None):
        # Each line of the annotations file is "<filename> <caption>",
        # split on the first space only.
        with open(annotations_file, "r") as f:
            lines = f.readlines()
        self.img_labels = [line.strip().split(" ", 1) for line in lines]
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_name, text = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return text, image


def save_checkpoint(model, optimizer, epoch, checkpoint_path):
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
    }
    torch.save(checkpoint, checkpoint_path)


def load_checkpoint(model, optimizer, checkpoint_path):
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint["epoch"]
    return 0


def test_inference(model, text):
    model.eval()
    with torch.no_grad():
        generated_image = model(text)
    return generated_image


def visualize_image(image_tensor):
    image_tensor = image_tensor.squeeze(0).cpu().detach()
    # Normalize to [0, 1]; the epsilon guards against a constant image.
    image_tensor = (image_tensor - image_tensor.min()) / (
        image_tensor.max() - image_tensor.min() + 1e-8
    )
    image_tensor = image_tensor.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
    plt.imshow(image_tensor)
    plt.show()


if __name__ == "__main__":
    # Hyperparameters and paths
    batch_size = 4
    learning_rate = 1e-4
    num_epochs = 1000
    checkpoint_path = "checkpoint.pth"
    annotations_file = "annotations.txt"
    img_dir = "images/"

    # Initialize models
    text_encoder = TextEncoder("google-t5/t5-small")
    conditional_diffusion_model = ConditionalDiffusionModel()
    super_resolution_diffusion_model = SuperResolutionDiffusionModel()
    text_to_image_model = TextToImageModel(
        text_encoder, conditional_diffusion_model, super_resolution_diffusion_model
    ).to(device)

    # Optimizer and criterion
    optimizer = torch.optim.Adam(text_to_image_model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    # Resume from a checkpoint if one exists
    start_epoch = load_checkpoint(text_to_image_model, optimizer, checkpoint_path)

    # Transformations for the images
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    # Dataset and dataloader
    dataset = CustomDataset(annotations_file, img_dir, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Training loop
    text_to_image_model.train()
    for epoch in range(start_epoch, num_epochs):
        for text_batch, image_batch in dataloader:
            optimizer.zero_grad()
            # The default collate yields a tuple of strings; the tokenizer
            # expects a list.
            images = text_to_image_model(list(text_batch))
            target_images = image_batch.to(device)
            loss = criterion(images, target_images)
            loss.backward()
            optimizer.step()
        # Reports the last batch's loss for the epoch
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")
        save_checkpoint(text_to_image_model, optimizer, epoch + 1, checkpoint_path)

    print("Training completed.")

    # Test inference
    sample_text = "A big ape."
    generated_image = test_inference(text_to_image_model, sample_text)
    visualize_image(generated_image)
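# ---------------------------------------------------------------------------
# Expected on-disk layout, a sketch inferred from CustomDataset above: the
# filenames and captions shown are illustrative assumptions, not files the
# original script ships with. Each line of annotations.txt is
# "<filename> <caption>" (split on the first space), and every filename must
# resolve inside img_dir:
#
#   images/
#       ape.jpg
#       cat.jpg
#   annotations.txt:
#       ape.jpg A big ape standing in the jungle.
#       cat.jpg A small cat sleeping on a couch.
# ---------------------------------------------------------------------------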