import os

import gradio as gr
import torch
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import CLIPModel, CLIPTokenizer
def train_image_generation_model(image_folder, text_folder, model_name="image_generation_model"):
    """Fine-tunes a CLIP model on the provided image/prompt dataset with a contrastive objective.

    Args:
        image_folder (str | list[str]): Folder of training images, or the list of
            file paths that gr.File(file_count="directory") passes in.
        text_folder (str | list[str]): Folder of .txt prompts (or list of paths),
            one prompt file per image, matched by sorted filename order.
        model_name (str, optional): Name for the saved model file.
            Defaults to "image_generation_model".

    Returns:
        str: Path to the saved model file.
    """
    class ImageTextDataset(Dataset):
        def __init__(self, images, texts, transform=None):
            def as_files(src):
                # gr.File(file_count="directory") passes a list of file paths;
                # a plain folder path is also accepted for local runs
                if isinstance(src, (list, tuple)):
                    return [str(f) for f in src]
                return [os.path.join(src, f) for f in os.listdir(src)]

            # Sort both sides so image i lines up with prompt i
            self.image_paths = sorted(f for f in as_files(images) if f.lower().endswith(('.png', '.jpg', '.jpeg')))
            self.text_paths = sorted(f for f in as_files(texts) if f.lower().endswith('.txt'))
            self.transform = transform

        def __len__(self):
            return len(self.image_paths)

        def __getitem__(self, idx):
            image = Image.open(self.image_paths[idx]).convert("RGB")
            if self.transform:
                image = self.transform(image)
            with open(self.text_paths[idx], 'r') as f:
                text = f.read().strip()
            return image, text
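    # Pairing assumption in this script: after sorting, the i-th image file
    # matches the i-th prompt file (e.g. images/0001.jpg <-> prompts/0001.txt)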
    # Load the pretrained CLIP model and tokenizer, moving the model to GPU
    # when one is available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    clip_model.train()
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    # Define the image preprocessing pipeline
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
    ])
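    # The mean/std above are the CLIP preprocessing statistics published for
    # openai/clip-vit-base-patch32, so inputs match the vision tower's pretraining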
    # Create dataset and dataloader
    dataset = ImageTextDataset(image_folder, text_folder, transform=transform)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
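    # Default collation stacks the image tensors into one batch tensor and
    # returns the prompts as a sequence of Python strings, which the tokenizer
    # in the training loop consumes directly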
    # Define optimizer and loss function
    optimizer = torch.optim.Adam(clip_model.parameters(), lr=1e-5)
    loss_fn = nn.CrossEntropyLoss()
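    # In-batch contrastive setup: with N image/text pairs per batch the
    # similarity matrix is N x N, and the correct "class" for row i is column i,
    # so cross-entropy against torch.arange(N) pulls matched pairs together
    # and pushes mismatched ones apart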
    # Train the model
    for epoch in range(10):
        for i, (images, texts) in enumerate(dataloader):
            optimizer.zero_grad()
            images = images.to(device)
            tokens = tokenizer(list(texts), padding=True, truncation=True, return_tensors="pt").to(device)
            image_features = clip_model.get_image_features(pixel_values=images)
            text_features = clip_model.get_text_features(input_ids=tokens["input_ids"], attention_mask=tokens["attention_mask"])
            # Normalize so the dot product below is a cosine similarity
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            similarity = clip_model.logit_scale.exp() * (image_features @ text_features.T)
            loss = loss_fn(similarity, torch.arange(images.size(0), device=device))
            loss.backward()
            optimizer.step()
            print(f"Epoch: {epoch} | Iteration: {i} | Loss: {loss.item():.4f}")
    # Save the fine-tuned weights
    model_name = model_name or "image_generation_model"
    model_path = os.path.join(os.getcwd(), model_name + ".pt")
    torch.save(clip_model.state_dict(), model_path)
    return model_path
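# Local usage sketch (assumes ./images and ./prompts folders whose filenames
# pair up after sorting; not part of the Space UI):
#
#   model_path = train_image_generation_model("images", "prompts", "demo_model")
#   print(model_path)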
# Define Gradio interface
iface = gr.Interface(
    fn=train_image_generation_model,
    inputs=[
        gr.File(label="Image Folder", file_count="directory"),
        gr.File(label="Text Prompts Folder", file_count="directory"),
        gr.Textbox(label="Model Name", value="image_generation_model"),
    ],
    outputs=gr.File(label="Model File"),
    title="Image Generation Model Trainer",
    description=(
        "Upload a folder of images and a folder of matching text prompts to "
        "fine-tune a CLIP model. The images folder should contain .png/.jpg/.jpeg "
        "files; the prompts folder should contain one .txt file per image, "
        "matched by sorted filename order."
    ),
)
iface.launch(share=True)