Spaces:
Runtime error
Runtime error
File size: 3,814 Bytes
1cbcd7d be37c92 1cbcd7d 78f07f4 e88f579 1cbcd7d e88f579 1cbcd7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
from huggingface_hub import cached_download, hf_hub_url
from PIL import Image
import os
import gradio as gr
import spaces
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPModel
def _as_path(item):
    """Return a filesystem path string for ``item``.

    ``item`` may be a ``str`` / ``os.PathLike``, or a file-like object whose
    ``.name`` attribute holds its on-disk path (as ``tempfile`` wrappers do).
    """
    if isinstance(item, (str, os.PathLike)):
        return os.fspath(item)
    return item.name


def _collect_files(source, extensions):
    """List the files in ``source`` whose names end with one of ``extensions``.

    Args:
        source: A directory path, a single file path, an iterable of file
            paths / file objects (what ``gr.File(file_count="directory")``
            passes to the callback), or ``None``.
        extensions: Tuple of lowercase suffixes, e.g. ``(".png", ".jpg")``.

    Returns:
        Sorted list of matching paths. Sorting makes the order deterministic,
        unlike raw ``os.listdir``.
    """
    if source is None:
        return []
    if isinstance(source, (str, os.PathLike)):
        source = os.fspath(source)
        if os.path.isdir(source):
            candidates = [os.path.join(source, name) for name in os.listdir(source)]
        else:
            candidates = [source]
    else:
        candidates = [_as_path(item) for item in source]
    return sorted(p for p in candidates if p.lower().endswith(extensions))


def _stem(path):
    """Return the file name of ``path`` without its directory or extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _pair_images_with_texts(image_paths, text_paths):
    """Pair each image path with the prompt file that describes it.

    Images are matched to ``.txt`` files with the same base name
    (``cat.png`` <-> ``cat.txt``). If the names do not line up, fall back to
    pairing the two sorted lists by position, which requires equal counts.

    Raises:
        ValueError: If no images were found or the files cannot be paired.
    """
    if not image_paths:
        raise ValueError("No .png/.jpg/.jpeg images were found in the image folder.")
    texts_by_stem = {_stem(p): p for p in text_paths}
    if all(_stem(p) in texts_by_stem for p in image_paths):
        return [(p, texts_by_stem[_stem(p)]) for p in image_paths]
    if len(image_paths) != len(text_paths):
        raise ValueError(
            f"Found {len(image_paths)} images but {len(text_paths)} prompt files, "
            "and their file names do not match; cannot pair them."
        )
    return list(zip(image_paths, text_paths))


class _ImageTextDataset(Dataset):
    """Dataset yielding ``(transformed_image, prompt_text)`` for each pair."""

    def __init__(self, pairs, transform=None):
        # pairs: list of (image_path, text_path) tuples.
        self.pairs = pairs
        self.transform = transform

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        image_path, text_path = self.pairs[idx]
        # convert() returns a new, fully loaded image, so the file handle
        # opened by Image.open can be closed right away.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        with open(text_path, "r", encoding="utf-8") as f:
            text = f.read().strip()
        return image, text


# NOTE(review): ZeroGPU allocations are time-limited; long runs may need
# @spaces.GPU(duration=...) — confirm against the `spaces` package docs.
@spaces.GPU()
def train_image_generation_model(
    image_folder,
    text_folder,
    model_name="image_generation_model",
    epochs=10,
    batch_size=8,
    learning_rate=1e-5,
):
    """Fine-tune CLIP (openai/clip-vit-base-patch32) on image/prompt pairs.

    Despite the historical name, this trains a CLIP image/text embedding model
    with its contrastive loss; it does not train an image generator.

    Args:
        image_folder: Directory path, or the list of uploaded files from a
            Gradio directory upload, containing .png/.jpg/.jpeg images.
        text_folder: Directory path or uploaded-file list containing one .txt
            prompt per image. Matched by base name, else by sorted order.
        model_name (str, optional): Base name for the saved weights file.
            Empty/blank values fall back to "image_generation_model".
        epochs (int, optional): Number of passes over the dataset.
        batch_size (int, optional): Pairs per optimization step.
        learning_rate (float, optional): Adam learning rate.

    Returns:
        str: Path to the saved ``state_dict`` file (``<model_name>.pt`` in the
        current working directory).

    Raises:
        ValueError: If no images are found or images and prompts can't be paired.
    """
    pairs = _pair_images_with_texts(
        _collect_files(image_folder, (".png", ".jpg", ".jpeg")),
        _collect_files(text_folder, (".txt",)),
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    # CLIP's expected input size and normalization statistics.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
    ])

    dataloader = DataLoader(
        _ImageTextDataset(pairs, transform=transform),
        batch_size=batch_size,
        shuffle=True,
    )
    optimizer = torch.optim.Adam(clip_model.parameters(), lr=learning_rate)

    # from_pretrained returns the model in eval mode.
    clip_model.train()
    for epoch in range(epochs):
        for i, (images, texts) in enumerate(dataloader):
            # padding: prompts in a batch differ in length; truncation: CLIP's
            # text encoder accepts at most 77 tokens.
            text_inputs = tokenizer(
                list(texts), padding=True, truncation=True, return_tensors="pt"
            ).to(device)
            # return_loss=True computes CLIP's symmetric contrastive loss on
            # normalized, temperature-scaled embeddings.
            outputs = clip_model(
                input_ids=text_inputs["input_ids"],
                attention_mask=text_inputs["attention_mask"],
                pixel_values=images.to(device),
                return_loss=True,
            )
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f"Epoch: {epoch} | Iteration: {i} | Loss: {loss.item()}")

    # A blank UI textbox would otherwise produce a file literally named ".pt".
    name = (model_name or "").strip() or "image_generation_model"
    model_path = os.path.join(os.getcwd(), name + ".pt")
    torch.save(clip_model.state_dict(), model_path)
    return model_path
# Gradio UI: two directory uploads (images, prompts) plus the output file name.
# NOTE: a gr.File with file_count="directory" passes the uploaded files to the
# callback rather than a directory path.
iface = gr.Interface(
    fn=train_image_generation_model,
    inputs=[
        gr.File(label="Image Folder", file_count="directory"),
        gr.File(label="Text Prompts Folder", file_count="directory"),
        # Pre-filled so an untouched box doesn't save the model as ".pt".
        gr.Textbox(label="Model Name", value="image_generation_model"),
    ],
    outputs=gr.File(label="Model File"),
    title="Image Generation Model Trainer",
    description=(
        "Upload a folder of images and a folder of matching text prompts to train a model.\n"
        "The images folder should contain .png/.jpg/.jpeg files. The prompts folder should "
        "contain .txt files, one prompt for each image in the images folder."
    ),
)

if __name__ == "__main__":
    iface.launch(share=True)