import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from transformers import T5ForConditionalGeneration, T5Tokenizer
import matplotlib.pyplot as plt

device = "cpu"
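
# Pipeline overview: a T5 encoder turns the prompt into a 512-dim embedding,
# an MLP projects it down to a 64-dim conditioning vector, that vector is
# reshaped and upsampled into a 64x64 seed image, and six convolutional
# refinement passes produce the final (3, 64, 64) output.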


class TextEncoder(nn.Module):
    """Wraps a T5 model and returns one embedding vector per input text."""

    def __init__(self, encoder_model_name):
        super(TextEncoder, self).__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(encoder_model_name)
        self.encoder = T5ForConditionalGeneration.from_pretrained(encoder_model_name)
        self.encoder.to(device)

    def encode_text(self, text):
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        inputs = {key: value.to(device) for key, value in inputs.items()}
        outputs = self.encoder.encoder(**inputs)  # run only the encoder stack
        embeddings = outputs.last_hidden_state[:, 0, :]  # first-token embedding as the sentence vector
        return embeddings
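
# Usage sketch (shapes assume the t5-small checkpoint, whose hidden size is 512):
#   encoder = TextEncoder("google-t5/t5-small")
#   encoder.encode_text(["a red square"])  # -> tensor of shape (1, 512)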


class ConditionalDiffusionModel(nn.Module):
    """Projects a 512-dim text embedding down to a 64-dim conditioning vector."""

    def __init__(self):
        super(ConditionalDiffusionModel, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(512, 768),  # input size 512 matches the t5-small hidden size
            nn.ReLU(),
            nn.Linear(768, 64),
            nn.ReLU(),
            nn.Linear(64, 64)
        )

    def forward(self, text_embeddings):
        return self.model(text_embeddings)
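
# Shape sketch: ConditionalDiffusionModel()(torch.randn(4, 512)) -> tensor of shape (4, 64).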


class SuperResolutionDiffusionModel(nn.Module):
    """Refines a 3-channel image; every conv uses padding=1, so resolution is preserved."""

    def __init__(self):
        super(SuperResolutionDiffusionModel, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),  # 3 input color channels
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )

    def forward(self, input_image):
        return self.model(input_image)
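
# Shape sketch: SuperResolutionDiffusionModel()(torch.randn(1, 3, 64, 64)) keeps shape (1, 3, 64, 64).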


class TextToImageModel(nn.Module):
    def __init__(self, text_encoder, conditional_diffusion_model, super_resolution_diffusion_model):
        super(TextToImageModel, self).__init__()
        self.text_encoder = text_encoder
        self.conditional_diffusion_model = conditional_diffusion_model
        self.super_resolution_diffusion_model = super_resolution_diffusion_model

    def forward(self, text):
        text_embeddings = self.text_encoder.encode_text(text)
        image_embeddings = self.conditional_diffusion_model(text_embeddings)
        # Seed the image from the conditioning vector so the text influences
        # the output: view the 64 values as a 1x8x8 map, upsample to 64x64,
        # and tile across the 3 color channels.
        batch_size = image_embeddings.size(0)
        cond = image_embeddings.view(batch_size, 1, 8, 8)
        cond = nn.functional.interpolate(cond, size=(64, 64), mode="bilinear", align_corners=False)
        input_image = cond.expand(batch_size, 3, 64, 64).contiguous()
        for _ in range(6):  # six refinement passes (resolution is unchanged)
            input_image = self.super_resolution_diffusion_model(input_image)
        return input_image


class CustomDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None):
        # Each annotation line is "<image_filename> <caption>", split on the first space.
        with open(annotations_file, 'r') as f:
            lines = f.readlines()
        self.img_labels = [line.strip().split(' ', 1) for line in lines]
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_name, text = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return text, image
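
# Example annotations.txt line (hypothetical filename, for illustration):
#   cat_001.jpg A small cat sitting on a windowsill.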


def save_checkpoint(model, optimizer, epoch, checkpoint_path):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch
    }
    torch.save(checkpoint, checkpoint_path)


def load_checkpoint(model, optimizer, checkpoint_path):
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch']
    return 0


def test_inference(model, text):
    model.eval()
    with torch.no_grad():
        generated_image = model(text)
    return generated_image


def visualize_image(image_tensor):
    image_tensor = image_tensor.squeeze(0).cpu().detach()
    # Min-max normalize to [0, 1] for display; the small eps guards against a constant image.
    image_tensor = (image_tensor - image_tensor.min()) / (image_tensor.max() - image_tensor.min() + 1e-8)
    image_tensor = image_tensor.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
    plt.imshow(image_tensor)
    plt.show()


if __name__ == "__main__":
    # Define hyperparameters and paths
    batch_size = 4
    learning_rate = 1e-4
    num_epochs = 1000
    checkpoint_path = 'checkpoint.pth'
    annotations_file = 'annotations.txt'
    img_dir = 'images/'

    # Initialize models
    text_encoder = TextEncoder("google-t5/t5-small")
    conditional_diffusion_model = ConditionalDiffusionModel()
    super_resolution_diffusion_model = SuperResolutionDiffusionModel()
    text_to_image_model = TextToImageModel(text_encoder, conditional_diffusion_model, super_resolution_diffusion_model)
    text_to_image_model.to(device)

    # Define optimizer and criterion
    optimizer = torch.optim.Adam(text_to_image_model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    # Load checkpoint if available
    start_epoch = load_checkpoint(text_to_image_model, optimizer, checkpoint_path)

    # Define transformations for the images
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    # Initialize dataset and dataloader
    dataset = CustomDataset(annotations_file, img_dir, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Training loop
    text_to_image_model.train()
    for epoch in range(start_epoch, num_epochs):
        for i, (text_batch, image_batch) in enumerate(dataloader):
            optimizer.zero_grad()
            images = text_to_image_model(list(text_batch))  # batch of caption strings
            target_images = image_batch.to(device)
            loss = criterion(images, target_images)
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')
        save_checkpoint(text_to_image_model, optimizer, epoch+1, checkpoint_path)
    print("Training completed.")

    # Test inference
    sample_text = "A big ape."
    generated_image = test_inference(text_to_image_model, sample_text)
    visualize_image(generated_image)