import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import HfApi, create_repo
import pandas as pd
import gradio as gr
from PIL import Image
import numpy as np

from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet
from CLIP import load as load_clip
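# Note: small_256_model, big_1024_model, and CLIP are local modules expected to live
# alongside this script in the Space repo; they are not pip-installable packages.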
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
big = device == torch.device('cpu')  # big (1024px) model on CPU, small (256px) on GPU

# Parameters
IMG_SIZE = 1024 if big else 256
BATCH_SIZE = 1 if big else 4
EPOCHS = 12
LR = 0.0002
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"

# Global model variable
global_model = None

# CLIP model and tokenizer from the local CLIP module
clip_model, clip_tokenizer = load_clip()
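# Only the tokenizer is used in this file; clip_model is loaded but currently unused.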
def load_model():
    """Load the model at startup; fall back to a fresh, untrained model on failure."""
    global global_model
    weights_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
    try:
        checkpoint = torch.load(weights_name, map_location=device)
        model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
        global_model = model
        print("Model loaded successfully!")
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        model = big_UNet().to(device) if big else small_UNet().to(device)
        global_model = model
        return model
class Pix2PixDataset(torch.utils.data.Dataset):
    def __init__(self, ds, transform, clip_tokenizer, csv_path='combined_data.csv'):
        if not os.path.exists(csv_path):
            os.system('wget https://huggingface.co/datasets/K00B404/pix2pix_flux_set/resolve/main/combined_data.csv')
        self.data = pd.read_csv(csv_path)
        self.clip_tokenizer = clip_tokenizer
        self.originals = [x for x in ds["train"] if x['label'] == 0]
        self.targets = [x for x in ds["train"] if x['label'] == 1]
        assert len(self.originals) == len(self.targets)
        print(f"Number of original images: {len(self.originals)}")
        print(f"Number of target images: {len(self.targets)}")
        self.transform = transform

    def __len__(self):
        return len(self.originals)

    def __getitem__(self, idx):
        # Get original and target images
        original_img = self.originals[idx]['image']
        target_img = self.targets[idx]['image']

        # Convert PIL images
        original = original_img.convert('RGB')
        target = target_img.convert('RGB')

        # Extract the filename from the original image's path (assumes the decoded
        # PIL image still carries a 'filename' attribute)
        original_img_path = self.originals[idx]['image'].filename
        original_img_filename = os.path.basename(original_img_path)

        # Match the image filename against the `image_path` column in the CSV
        # (regex=False so filenames with special characters match literally)
        matched_row = self.data[self.data['image_path'].str.contains(original_img_filename, regex=False)]
        if matched_row.empty:
            raise ValueError(f"No matching entry found in the CSV for image {original_img_filename}")

        # Get the prompts from the matched row
        original_prompt = matched_row['original_prompt'].values[0]
        enhanced_prompt = matched_row['enhanced_prompt'].values[0]

        # Tokenize the prompts with the CLIP tokenizer; pad to the fixed CLIP context
        # length (77) so both prompts have identical shape and batch cleanly
        original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt",
                                              padding="max_length", truncation=True, max_length=77)
        enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt",
                                              padding="max_length", truncation=True, max_length=77)

        # Return transformed images and token-id tensors (shape [77] each, so the
        # default collate stacks them into [batch, 77])
        return (self.transform(original), self.transform(target),
                original_tokens.input_ids.squeeze(0), enhanced_tokens.input_ids.squeeze(0))
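# Minimal usage sketch for Pix2PixDataset (assumes the dataset and CSV above are reachable):
#
#   ds = load_dataset(dataset_id)
#   transform = transforms.Compose([transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor()])
#   dataset = Pix2PixDataset(ds, transform, clip_tokenizer)
#   original, target, orig_ids, enh_ids = dataset[0]
#   # original/target: [3, IMG_SIZE, IMG_SIZE] float tensors; orig_ids/enh_ids: [77] token-id tensors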
# Older dataset class, kept for reference
class Pix2PixDataset_old(torch.utils.data.Dataset):
    def __init__(self, ds, transform, csv_path='combined_data.csv'):
        if not os.path.exists(csv_path):
            os.system('wget https://huggingface.co/datasets/K00B404/pix2pix_flux_set/resolve/main/combined_data.csv')
        self.data = pd.read_csv(csv_path)
        self.clip_tokenizer = clip_tokenizer
        self.originals = [x for x in ds["train"] if x['label'] == 0]
        self.targets = [x for x in ds["train"] if x['label'] == 1]
        assert len(self.originals) == len(self.targets)
        print(f"Number of original images: {len(self.originals)}")
        print(f"Number of target images: {len(self.targets)}")
        self.transform = transform

    def __len__(self):
        return len(self.originals)

    def __getitem__(self, idx):
        original_img = self.originals[idx]['image']
        # TODO: get original_img file name, match it with image_path in self.data, then tokenize the prompts with clip_tokenizer
        target_img = self.targets[idx]['image']
        original = original_img.convert('RGB')
        target = target_img.convert('RGB')
        return self.transform(original), self.transform(target)
class UNetWrapper:
    def __init__(self, unet_model, repo_id):
        self.model = unet_model
        self.repo_id = repo_id
        self.token = os.getenv('NEW_TOKEN')  # Make sure this environment variable is set
        self.api = HfApi(token=self.token)

    def push_to_hub(self):
        try:
            # Save model state and configuration
            save_dict = {
                'model_state_dict': self.model.state_dict(),
                'model_config': {
                    'big': isinstance(self.model, big_UNet),
                    'img_size': 1024 if isinstance(self.model, big_UNet) else 256
                },
                'model_architecture': str(self.model)
            }

            # Save model locally
            pth_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
            torch.save(save_dict, pth_name)

            # Create repo if it doesn't exist
            try:
                create_repo(repo_id=self.repo_id, token=self.token, exist_ok=True)
            except Exception as e:
                print(f"Repository creation note: {e}")

            # Upload the model file
            self.api.upload_file(
                path_or_fileobj=pth_name,
                path_in_repo=pth_name,
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )

            # Create and upload the model card; size and type are filled in from the
            # saved config so the card stays correct for both model variants
            img_size = save_dict['model_config']['img_size']
            model_type = 'Big (1024)' if save_dict['model_config']['big'] else 'Small (256)'
            model_card = f"""---
tags:
- unet
- pix2pix
- pytorch
library_name: pytorch
license: wtfpl
datasets:
- K00B404/pix2pix_flux_set
language:
- en
pipeline_tag: image-to-image
---

# Pix2Pix UNet Model

## Model Description
Custom UNet model for Pix2Pix image translation.
- **Image Size:** {img_size}
- **Model Type:** {model_type}

## Usage

```python
import torch
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

big = True  # set to False for the 256px model

# Load the model
name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
checkpoint = torch.load(name)
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```

## Model Architecture
```
{str(self.model)}
```
"""
            # Save and upload README
            with open("README.md", "w") as f:
                f.write(model_card)
            self.api.upload_file(
                path_or_fileobj="README.md",
                path_in_repo="README.md",
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model"
            )

            # Clean up local files
            os.remove(pth_name)
            os.remove("README.md")
            print(f"Model successfully uploaded to {self.repo_id}")
        except Exception as e:
            print(f"Error uploading model: {e}")
def prepare_input(image, device='cpu'):
    """Prepare image for inference"""
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    input_tensor = transform(image).unsqueeze(0).to(device)
    return input_tensor
def run_inference(image, prompt=None):
    """Run inference on a single image (the prompt is accepted but not yet used by the UNet)."""
    global global_model
    if global_model is None:
        return "Error: Model not loaded"
    global_model.eval()
    input_tensor = prepare_input(image, device)
    with torch.no_grad():
        output = global_model(input_tensor)

    # Convert the output tensor to a uint8 image, min-max normalized to [0, 255]
    output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
    output = ((output - output.min()) / (output.max() - output.min()) * 255).astype(np.uint8)
    return output
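# Minimal inference sketch (assumes a local test image 'example.jpg'; the path is illustrative):
#
#   load_model()
#   out = run_inference(Image.open('example.jpg'))  # HxWx3 uint8 numpy array
#   Image.fromarray(out).save('output.png')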
def to_hub(model):
    wrapper = UNetWrapper(model, model_repo_id)
    wrapper.push_to_hub()
def train_model(epochs):
    """Training function"""
    global global_model

    ds = load_dataset(dataset_id)
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])

    # Initialize the dataset and dataloader
    dataset = Pix2PixDataset(ds, transform, clip_tokenizer)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    model = global_model
    criterion = nn.L1Loss()  # L1 loss for image reconstruction
    optimizer = optim.Adam(model.parameters(), lr=LR)
    output_text = []

    for epoch in range(epochs):
        model.train()
        for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
            # Move images and prompt token ids to the appropriate device (CPU or GPU)
            original, target = original.to(device), target.to(device)
            original_prompt_tokens = original_prompt_tokens.to(device)
            enhanced_prompt_tokens = enhanced_prompt_tokens.to(device)

            optimizer.zero_grad()

            # Forward pass through the model
            output = model(target)

            # Compute image reconstruction loss
            img_loss = criterion(output, original)

            # Prompt guidance term: L2 distance between the two prompts' token ids
            # (cast to float; note this is computed on token ids, not CLIP embeddings,
            # and carries no gradient w.r.t. the model parameters)
            prompt_loss = torch.norm(original_prompt_tokens.float() - enhanced_prompt_tokens.float(), p=2)

            # Combine losses; the prompt guidance term is weighted with 0.1 to balance
            total_loss = img_loss + 0.1 * prompt_loss
            total_loss.backward()

            # Optimizer step
            optimizer.step()

            if i % 10 == 0:
                status = f"Epoch [{epoch + 1}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
                print(status)
                output_text.append(status)

        # Push model to Hugging Face Hub at the end of each epoch
        to_hub(model)

    global_model = model  # Update the global model after training
    return model, "\n".join(output_text)
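# Headless training sketch (bypasses the Gradio UI; pushes to the Hub each epoch,
# so NEW_TOKEN must be set in the environment):
#
#   load_model()
#   model, log = train_model(EPOCHS)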
def gradio_train(epochs):
    """Gradio training interface function"""
    model, training_log = train_model(int(epochs))
    to_hub(model)  # final push (train_model also pushes at the end of each epoch)
    return f"{training_log}\n\nModel trained for {epochs} epochs and pushed to {model_repo_id}"
def gradio_inference_with_prompt(input_image, keywords):
    """Prompt-enhanced inference variant. Relies on a chat_with_bot helper that is
    not defined in this file, so it is kept out of the UI wiring below."""
    # Generate an enhanced prompt from the keywords
    enhanced_prompt = chat_with_bot(keywords)
    # Run inference on the input image with the enhanced prompt
    output_image = run_inference(input_image, enhanced_prompt)
    return input_image, output_image, keywords, enhanced_prompt


def gradio_inference(input_image):
    """Gradio inference interface function"""
    return input_image, run_inference(input_image)
# Create Gradio interface with tabs
with gr.Blocks() as app:
    gr.Markdown("# Pix2Pix Model Training and Inference")

    with gr.Tabs():
        with gr.TabItem("Training"):
            epochs_input = gr.Number(label="Number of Epochs")
            train_button = gr.Button("Train Model")
            output_text = gr.Textbox(label="Training Progress", lines=10)
            train_button.click(gradio_train, inputs=epochs_input, outputs=output_text)

        with gr.TabItem("Inference"):
            with gr.Row():
                input_image = gr.Image(label="Input Image")
                output_image = gr.Image(label="Model Output")
            infer_button = gr.Button("Run Inference")
            infer_button.click(gradio_inference, inputs=input_image, outputs=[input_image, output_image])

if __name__ == '__main__':
    # Load model at startup
    load_model()
    # Launch the Gradio app
    app.launch()