|
import os |
|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as transforms |
|
from torch.utils.data import DataLoader, Dataset |
|
from PIL import Image |
|
from transformers import T5ForConditionalGeneration, T5Tokenizer |
|
import matplotlib.pyplot as plt |
|
device ="cpu" |
|
class TextEncoder(nn.Module):
    """Wraps a pretrained T5 checkpoint and exposes text -> embedding encoding.

    Only the encoder stack of the seq2seq model is used; the hidden state at
    position 0 of the last layer is returned as the sentence embedding.
    """

    def __init__(self, encoder_model_name):
        super(TextEncoder, self).__init__()
        # Tokenizer and model are loaded from the same pretrained checkpoint.
        self.tokenizer = T5Tokenizer.from_pretrained(encoder_model_name)
        self.encoder = T5ForConditionalGeneration.from_pretrained(encoder_model_name)
        self.encoder.to(device)

    def encode_text(self, text):
        """Tokenize *text* (str or list of str) and return a (batch, hidden) tensor.

        NOTE(review): T5 has no CLS token, so position 0 is simply the first
        input token; mean pooling may summarize the sequence better — confirm
        this slicing is intentional.
        """
        tokenized = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        tokenized = {name: tensor.to(device) for name, tensor in tokenized.items()}
        encoder_output = self.encoder.encoder(**tokenized)
        # First-token slice of the final hidden states: (batch, hidden_dim).
        return encoder_output.last_hidden_state[:, 0, :]
|
|
|
class ConditionalDiffusionModel(nn.Module):
    """Maps a 512-dim text embedding to a 64-dim conditioning vector.

    Despite the name, this is a plain 3-layer MLP (512 -> 768 -> 64 -> 64)
    with ReLU activations between the linear layers; no diffusion process
    is implemented here.
    """

    def __init__(self):
        super(ConditionalDiffusionModel, self).__init__()
        layers = [
            nn.Linear(512, 768),
            nn.ReLU(),
            nn.Linear(768, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, text_embeddings):
        """Project (batch, 512) text embeddings to (batch, 64)."""
        return self.model(text_embeddings)
|
|
|
class SuperResolutionDiffusionModel(nn.Module):
    """Shape-preserving 3-layer CNN over RGB images.

    All convolutions use kernel 3 with padding 1, so spatial dimensions are
    unchanged (N, 3, H, W) -> (N, 3, H, W). Despite the name it performs no
    upsampling and no diffusion step.
    """

    def __init__(self):
        super(SuperResolutionDiffusionModel, self).__init__()
        conv_stack = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
        ]
        self.model = nn.Sequential(*conv_stack)

    def forward(self, input_image):
        """Refine an (N, 3, H, W) image tensor; output has the same shape."""
        return self.model(input_image)
|
|
|
class TextToImageModel(nn.Module):
    """End-to-end text -> image pipeline.

    Composes a text encoder, a conditioning MLP, and a shape-preserving CNN
    that iteratively refines a random noise canvas into the output image.
    """

    def __init__(self, text_encoder, conditional_diffusion_model, super_resolution_diffusion_model):
        super(TextToImageModel, self).__init__()
        self.text_encoder = text_encoder
        self.conditional_diffusion_model = conditional_diffusion_model
        self.super_resolution_diffusion_model = super_resolution_diffusion_model

    def forward(self, text):
        """Generate a (batch, 3, 64, 64) image tensor from *text*.

        *text* is whatever the text encoder's ``encode_text`` accepts
        (a string or list of strings).
        """
        text_embeddings = self.text_encoder.encode_text(text)
        # NOTE(review): this conditioning vector is computed but never fed
        # into image generation, so the output does not depend on the text.
        # Wire image_embeddings into the refinement stage to make the
        # conditioning real.
        image_embeddings = self.conditional_diffusion_model(text_embeddings)
        # Fixes vs. original:
        #  * batch size follows the input batch instead of being hard-coded
        #    to 1 (training previously compared a single generated image
        #    against a whole batch of targets in the MSE loss);
        #  * the noise canvas is created on the embeddings' device instead
        #    of always on CPU.
        batch_size = text_embeddings.shape[0]
        input_image = torch.rand((batch_size, 3, 64, 64), device=text_embeddings.device)
        # Iteratively refine the canvas; each pass preserves the shape.
        for _ in range(6):
            input_image = self.super_resolution_diffusion_model(input_image)
        return input_image
|
|
|
class CustomDataset(Dataset):
    """Image/caption dataset backed by a plain-text annotations file.

    Each line of *annotations_file* is ``<image-filename> <caption...>``:
    the first space separates the file name from the caption, which may
    itself contain spaces.
    """

    def __init__(self, annotations_file, img_dir, transform=None):
        with open(annotations_file, 'r') as handle:
            records = handle.readlines()
        # Each entry is [filename, caption]; split only on the first space.
        self.img_labels = [record.strip().split(' ', 1) for record in records]
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Number of annotated samples."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the (caption, image) pair at *idx*.

        The image is loaded as RGB and passed through *transform* when one
        was supplied.
        """
        name, caption = self.img_labels[idx]
        image = Image.open(os.path.join(self.img_dir, name)).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return caption, image
|
|
|
def save_checkpoint(model, optimizer, epoch, checkpoint_path):
    """Serialize model/optimizer state plus the epoch counter to *checkpoint_path*.

    The file layout (keys 'model_state_dict', 'optimizer_state_dict',
    'epoch') is the one expected by ``load_checkpoint``.
    """
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
        },
        checkpoint_path,
    )
|
|
|
def load_checkpoint(model, optimizer, checkpoint_path):
    """Restore model/optimizer state saved by ``save_checkpoint``.

    Returns the stored epoch number, or 0 when no checkpoint file exists
    (i.e. training should start from scratch).
    """
    # Guard clause: nothing to resume from.
    if not os.path.isfile(checkpoint_path):
        return 0
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']
|
|
|
def test_inference(model, text): |
|
model.eval() |
|
with torch.no_grad(): |
|
generated_image = model(text) |
|
return generated_image |
|
|
|
def visualize_image(image_tensor):
    """Min-max normalize a (1, C, H, W) image tensor and display it with matplotlib.

    The leading batch dimension is squeezed away and the tensor is moved to
    CPU and detached before plotting.
    """
    image_tensor = image_tensor.squeeze(0).cpu().detach()
    lo, hi = image_tensor.min(), image_tensor.max()
    if hi > lo:
        # Rescale values into [0, 1] for imshow.
        image_tensor = (image_tensor - lo) / (hi - lo)
    else:
        # Bug fix: a constant image made the original min-max formula divide
        # by zero and produce NaNs; render it as mid-gray instead.
        image_tensor = torch.full_like(image_tensor, 0.5)
    image_tensor = image_tensor.permute(1, 2, 0)  # CHW -> HWC for imshow
    plt.imshow(image_tensor)
    plt.show()
|
|
|
if __name__ == "__main__":

    # --- Hyperparameters and file locations ---
    batch_size = 4
    learning_rate = 1e-4
    num_epochs = 1000
    checkpoint_path = 'checkpoint.pth'
    annotations_file = 'annotations.txt'  # lines of "<image-file> <caption>"
    img_dir = 'images/'

    # --- Build the three-stage pipeline around a pretrained T5 encoder ---
    text_encoder = TextEncoder("google-t5/t5-small")
    conditional_diffusion_model = ConditionalDiffusionModel()
    super_resolution_diffusion_model = SuperResolutionDiffusionModel()
    text_to_image_model = TextToImageModel(text_encoder, conditional_diffusion_model, super_resolution_diffusion_model)

    # One optimizer over all sub-model parameters; pixel-wise MSE loss.
    optimizer = torch.optim.Adam(text_to_image_model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    # Resume from the last checkpoint when one exists (returns 0 otherwise).
    start_epoch = load_checkpoint(text_to_image_model, optimizer, checkpoint_path)

    # Targets are resized to the pipeline's fixed 64x64 resolution and
    # converted to tensors in [0, 1].
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    dataset = CustomDataset(annotations_file, img_dir, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # --- Training loop: generate images from captions, regress onto the
    # --- ground-truth images, checkpoint after every epoch.
    text_to_image_model.train()
    for epoch in range(start_epoch, num_epochs):
        for i, (text_batch, image_batch) in enumerate(dataloader):
            optimizer.zero_grad()
            images = text_to_image_model(text_batch)
            target_images = image_batch.to(device)
            loss = criterion(images, target_images)
            loss.backward()
            optimizer.step()

        # NOTE(review): `loss` is unbound here if the dataloader yields no
        # batches (empty annotations file) — the print would raise NameError.
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')
        save_checkpoint(text_to_image_model, optimizer, epoch+1, checkpoint_path)

    print("Training completed.")

    # Quick qualitative check: generate and display one sample image.
    sample_text = "A big ape."
    generated_image = test_inference(text_to_image_model, sample_text)
    visualize_image(generated_image)
|