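"""Minimal text-to-image training script.

Encodes captions with a pretrained T5 (or DistilBERT) encoder, maps the pooled
embedding through a small MLP ("conditional diffusion" model), refines a noise
image with a convolutional stack ("super-resolution" model), and trains the
pipeline end-to-end with an MSE loss against reference images. Supports either
a local annotation-file dataset or a Hugging Face dataset, checkpoint
save/load, and optional wandb logging. Expects a local credits.py providing
the API tokens imported below.
"""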
import os
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset
from PIL import Image
from datasets import load_dataset
from transformers import T5EncoderModel, T5Tokenizer, DistilBertModel, DistilBertTokenizer
import matplotlib.pyplot as plt
from rich import print as rp
from credits import HUGGINGFACE_TOKEN, HUGGINGFACE_W_TOKEN, WANDB_API_KEY
import wandb  # Weights & Biases experiment tracking
import torchvision.utils as vutils # To save image grids
write_token = HUGGINGFACE_W_TOKEN
read_token = HUGGINGFACE_TOKEN
class TextEncoder(nn.Module):
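    """Wraps a pretrained T5 or DistilBERT encoder and exposes encode_text(),
    which returns one pooled embedding per input prompt."""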
def __init__(self, encoder_model_name, encoder_type="t5", device='cpu'):
super(TextEncoder, self).__init__()
self.device = device
self.encoder_type = encoder_type
if encoder_type == "t5":
self.tokenizer = T5Tokenizer.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
self.encoder = T5EncoderModel.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
elif encoder_type == "distilbert":
self.tokenizer = DistilBertTokenizer.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
self.encoder = DistilBertModel.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
else:
raise ValueError(f"Invalid encoder_type: {encoder_type}. Choose from 't5' or 'distilbert'.")
self.encoder.to(self.device)
    def encode_text(self, text):
        if isinstance(text, str):
            text = [text]
        # Tokenize the whole batch at once; use the first token's hidden state
        # as a pooled sentence embedding of shape (batch, hidden_size).
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(self.device)
        outputs = self.encoder(**inputs)
        return outputs.last_hidden_state[:, 0, :]
class ConditionalDiffusionModel(nn.Module):
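    """A small MLP that maps a pooled text embedding to a 64-d latent.
    Despite the name, no actual diffusion process is implemented here."""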
def __init__(self):
super(ConditionalDiffusionModel, self).__init__()
self.model = nn.Sequential(
            nn.Linear(512, 768),  # 512 = t5-small hidden size (distilbert would need 768 here)
nn.ReLU(),
nn.Linear(768, 64),
nn.ReLU(),
nn.Linear(64, 64)
)
def forward(self, text_embeddings):
return self.model(text_embeddings)
class SuperResolutionDiffusionModel(nn.Module):
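    """A 3-layer convolutional refiner. kernel_size=3 with padding=1 preserves
    spatial resolution, so repeated passes refine the image without upscaling."""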
def __init__(self):
super(SuperResolutionDiffusionModel, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1), # 3 is the number of color channels
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 3, kernel_size=3, padding=1)
)
def forward(self, input_image):
return self.model(input_image)
class TextToImageModel(nn.Module):
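    """End-to-end pipeline: text encoder -> conditional MLP -> convolutional refiner."""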
def __init__(self, text_encoder, conditional_diffusion_model, super_resolution_diffusion_model):
super(TextToImageModel, self).__init__()
self.text_encoder = text_encoder
self.conditional_diffusion_model = conditional_diffusion_model
self.super_resolution_diffusion_model = super_resolution_diffusion_model
def forward(self, text):
text_embeddings = self.text_encoder.encode_text(text)
image_embeddings = self.conditional_diffusion_model(text_embeddings)
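        # NOTE: image_embeddings is currently unused below -- the refiner starts
        # from pure noise, so the generated image is not conditioned on the text.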
input_image = torch.rand((image_embeddings.shape[0], 3, 128, 128)).to(text_embeddings.device)
        for _ in range(6):  # Refine the noise image with 6 passes (resolution stays 128x128; no actual upsampling)
input_image = self.super_resolution_diffusion_model(input_image)
return input_image
class CustomDataset(Dataset):
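    """Local dataset: reads an annotations file with one '<image_filename> <caption>'
    pair per line and yields (caption, image_tensor) resized to size_sqr x size_sqr."""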
def __init__(self, annotations_file, train_img_dir, size_sqr=128):
with open(annotations_file, 'r') as f:
lines = f.readlines()
self.transform = transforms.Compose([
transforms.Resize((size_sqr, size_sqr)),
transforms.ToTensor(),
])
self.img_labels = [line.strip().split(' ', 1) for line in lines]
self.train_img_dir = train_img_dir
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_name, text = self.img_labels[idx]
img_path = os.path.join(self.train_img_dir, img_name)
image = Image.open(img_path).convert("RGB")
if self.transform:
image = self.transform(image)
return text, image
class HuggingDataset(Dataset): # New class for HuggingFace dataset
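    """Hugging Face dataset wrapper yielding (caption, image_tensor) pairs."""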
    def __init__(self, dataset_name="vera365/lexica_dataset", size_sqr=128, limit=None):
        # Keep only the train split so the indexing below is uniform
        self.dataset = load_dataset(dataset_name, token=read_token, cache_dir='./datasets')["train"]
        self.transform = transforms.Compose([
            transforms.Resize((size_sqr, size_sqr)),
            transforms.ToTensor(),
        ])
        # Apply limit if specified
        if limit is not None:
            self.dataset = Subset(self.dataset, range(limit))
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item["image"].convert("RGB")
        text = item["subject"]
        if self.transform:
            image = self.transform(image)
        return text, image
class StorageHandler:
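    """Handles checkpoint save/load and optional pushes to the Hugging Face Hub."""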
    def __init__(self, storage_dir="./image_gen_storage", hub_model_name="K00B404/tiny_image_gen", push_dataset=False, dataset_name="K00B404/custom_image_descriptions_dataset"):
        self.model_name = hub_model_name
        self.dataset_name = dataset_name
        self.should_push_dataset = push_dataset  # renamed so the flag does not shadow the push_dataset() method
        self.storage_dir = storage_dir
def save_checkpoint(self, model, optimizer, scheduler, epoch, checkpoint_path):
checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'epoch': epoch
}
torch.save(checkpoint, checkpoint_path)
def load_checkpoint(self, model, optimizer, scheduler, checkpoint_path):
if os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
epoch = checkpoint['epoch']
return epoch, scheduler
else:
return 0, scheduler
    def push_dataset(self, dataset):
        # Optionally push the dataset to the Hugging Face Hub
        if dataset:
            dataset.push_to_hub(self.dataset_name, token=write_token)
    def push(self, model, tokenizer):
        model.push_to_hub(self.model_name, token=write_token)
        tokenizer.push_to_hub(self.model_name, token=write_token)
class Common:
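    """Training/evaluation utilities with console (rich) and optional wandb logging."""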
    def __init__(self, device='cpu', wandb_log=False):
        self.wandb_log = wandb_log
        self.device = device
        self.terminal_log = rp
        if self.wandb_log:
            # Initialize wandb; log in first with wandb.login(key=WANDB_API_KEY) if needed.
            # NOTE: these hyperparameters are read from the module-level globals
            # defined in the __main__ block below.
            self.wandb = wandb.init(project="my-image-generation-project",
                                    config={
                                        "learning_rate": learning_rate,
                                        "batch_size": batch_size,
                                        "num_epochs": num_epochs,
                                        "encoder_model": encoder
                                    })
def train(self, model, optimizer, scheduler, dataloader, criterion, num_epochs, start_epoch, checkpoint_path):
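        """Train for num_epochs, logging losses, saving sample image grids
        every 25 batches, and checkpointing at the end of each epoch."""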
for epoch in range(start_epoch, num_epochs):
model.train()
for i, (text_batch, image_batch) in enumerate(dataloader):
image_batch = image_batch.to(self.device)
optimizer.zero_grad()
images = model(text_batch)
loss = criterion(images, image_batch)
loss.backward()
optimizer.step()
if self.wandb_log:
# Log loss and learning rate
self.wandb.log({"train_loss": loss.item(), "lr": optimizer.param_groups[0]['lr']})
if i % 25 == 0:
# Save a grid of real and generated images for monitoring
img_grid_real = vutils.make_grid(image_batch[:4], padding=2, normalize=True)
                    img_grid_fake = vutils.make_grid(images[:4].detach(), padding=2, normalize=True)  # use the model output, not raw noise
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(img_grid_real.cpu(),(1,2,0)))
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Generated Images")
plt.imshow(np.transpose(img_grid_fake.cpu(),(1,2,0)))
plt.savefig(f'generated_images_epoch_{epoch+1}_batch_{i}.png')
plt.close()
            # Validation step (reuses the training dataloader; use a held-out split for a true validation loss)
            val_loss = self.evaluate(model, dataloader, criterion)
            scheduler.step(val_loss)  # ReduceLROnPlateau steps on the validation loss
image = self.test_inference(model, "A house next to a river.")
self.visualize_image(image, f'generated_image_epoch_{epoch + 1}.png')
self.terminal_log(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Validation Loss: {val_loss}')
StorageHandler().save_checkpoint(model, optimizer, scheduler, epoch + 1, checkpoint_path)
self.terminal_log("Training completed.")
def evaluate(self, model, dataloader, criterion):
model.eval()
total_loss = 0
with torch.no_grad():
for i, (text_batch, image_batch) in enumerate(dataloader):
image_batch = image_batch.to(self.device)
images = model(text_batch)
loss = criterion(images, image_batch)
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
self.terminal_log(f'Validation Loss: {avg_loss}')
return avg_loss
def test_inference(self, model, text):
model.eval()
with torch.no_grad():
if isinstance(text, str):
generated_image = model(text)
else:
generated_image = [model(t) for t in text]
return generated_image
    def visualize_image(self, image_tensor, filename='generated_image.png'):
        image_tensor = image_tensor.squeeze(0).cpu().detach()
        # Normalize to [0, 1]; the epsilon guards against a zero-range (constant) image
        image_tensor = (image_tensor - image_tensor.min()) / (image_tensor.max() - image_tensor.min() + 1e-8)
        image_tensor = image_tensor.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
        plt.imshow(image_tensor)
        plt.savefig(filename)
        plt.close()  # close the figure so repeated calls do not leak memory
if __name__ == "__main__":
batch_size = 1
learning_rate = 1e-4
num_epochs = 500
encoder = "google-t5/t5-small"
checkpoint_path = './models/image_gen'
os.makedirs(checkpoint_path, exist_ok=True)
checkpoint_file = f"{checkpoint_path}/checkpoint_backup.pth"
use_huggingface_dataset = False # <-- Toggle between datasets
limit_huggingface_dataset = 1000 # <-- Set the limit for HuggingFace dataset
train_img_dir = './train_images'
annotations_file = f'{train_img_dir}/annotations.txt'
storage_dir = "./image_gen_storage"
os.makedirs(storage_dir, exist_ok=True)
hub_model_name = "K00B404/tiny_image_gen"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
text_encoder = TextEncoder(encoder, encoder_type="t5", device=device)
conditional_diffusion_model = ConditionalDiffusionModel()
super_resolution_diffusion_model = SuperResolutionDiffusionModel()
text_to_image_model = TextToImageModel(text_encoder, conditional_diffusion_model, super_resolution_diffusion_model)
text_to_image_model.to(device)
optimizer = optim.AdamW(text_to_image_model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.001, patience=2)  # note: factor multiplies the LR, so 0.001 cuts it very aggressively
common = Common(device, True)
storage = StorageHandler(checkpoint_path, hub_model_name)
start_epoch, scheduler = storage.load_checkpoint(text_to_image_model, optimizer, scheduler, checkpoint_file)
if use_huggingface_dataset:
dataset = HuggingDataset(size_sqr=128, limit=limit_huggingface_dataset)
else:
dataset = CustomDataset(annotations_file, train_img_dir, size_sqr=128)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
common.train(text_to_image_model, optimizer, scheduler, dataloader, criterion, num_epochs, start_epoch, checkpoint_file)
    sample_texts = ["A big ape.", "A yellow banana."]
    for n, sample_text in enumerate(sample_texts):
        generated_image = common.test_inference(text_to_image_model, sample_text)
        common.visualize_image(generated_image, f'generated_image_sample_{n}.png')  # distinct filenames so samples are not overwritten