Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from torchvision import models | |
| import pandas as pd | |
| from datasets import load_dataset | |
| from torch.utils.data import DataLoader, Dataset | |
| from sklearn.preprocessing import LabelEncoder | |
# Fetch the first 10,000 training rows of the CivitAI Stable Diffusion dataset
# from the Hugging Face Hub.
dataset = load_dataset(
    "thefcraft/civitai-stable-diffusion-337k",
    split="train[:10000]",
)
# Text preprocessing: lowercase, whitespace-tokenize, and fix the length.
def preprocess_text(text, max_length=100):
    """Tokenize *text* into exactly ``max_length`` lowercase word tokens.

    Anything that is not a ``str`` (``None`` included) is treated as empty text.
    Token lists longer than ``max_length`` are truncated; shorter ones are
    right-padded with the empty string ``''``.

    Args:
        text: The prompt to tokenize; may be ``None`` or a non-string value.
        max_length: Number of tokens to return.

    Returns:
        A new list of ``max_length`` strings.
    """
    tokens = text.lower().split() if isinstance(text, str) else []
    # Slicing is a no-op when the list is already short enough.
    tokens = tokens[:max_length]
    # A negative repeat count yields [], so no padding is added when full.
    return tokens + [""] * (max_length - len(tokens))
class CustomDataset(Dataset):
    """Torch ``Dataset`` over the Hugging Face rows that have a ``Model`` value.

    Each item is ``(image_tensor, text_vector, label)``:

    * ``image_tensor`` -- the row's image converted to RGB, resized to 224x224
      and turned into a tensor;
    * ``text_vector`` -- word counts of the row's prompt over a vocabulary
      built from every prompt in the filtered rows;
    * ``label`` -- the ``LabelEncoder`` code of the row's ``Model`` string.

    Args:
        dataset: A Hugging Face dataset exposing ``Model``, ``prompt`` and
            ``image`` columns and supporting ``select``.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        # Rows without a Model have no label, so drop them up front.
        valid_indices = [i for i, model in enumerate(dataset['Model']) if model is not None]
        self.valid_dataset = dataset.select(valid_indices)
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(self.valid_dataset['Model'])

        # Build the vocabulary from every prompt.
        vocab = set()
        for item in self.valid_dataset['prompt']:
            try:
                vocab.update(preprocess_text(item))
            except Exception as e:
                print(f"Error processing prompt: {e}")
        # '' is the padding token emitted by preprocess_text; never count it.
        vocab.discard('')
        # Sort so the word -> column mapping is the same in every process:
        # set iteration order for str varies with PYTHONHASHSEED, which would
        # silently scramble the text model's input features between runs.
        self.vocab = sorted(vocab)
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}

    def __len__(self):
        return len(self.valid_dataset)

    def text_to_vector(self, text):
        """Return a float tensor of per-word counts of *text* over ``self.vocab``.

        Words not in the vocabulary (including the '' padding) are ignored.
        On any error an all-zero vector is returned.
        """
        try:
            words = preprocess_text(text)
            vector = torch.zeros(len(self.vocab))
            for word in words:
                if word in self.word_to_idx:
                    vector[self.word_to_idx[word]] += 1
            return vector
        except Exception as e:
            print(f"Error converting text to vector: {e}")
            return torch.zeros(len(self.vocab))

    def __getitem__(self, idx):
        """Return ``(image_tensor, text_vector, label)`` for filtered row *idx*.

        On any error, prints the error and returns zero tensors with label 0
        as a best-effort fallback.
        """
        try:
            # Decode the row once instead of once per column.
            row = self.valid_dataset[idx]
            image = row['image']
            # ToTensor keeps the source channel count (1 for 'L'/'P', 4 for
            # 'RGBA'); the ResNet backbone and batch stacking need 3 channels.
            if hasattr(image, "convert"):
                image = image.convert("RGB")
            image = self.transform(image)
            text_vector = self.text_to_vector(row['prompt'])
            label = self.labels[idx]
            return image, text_vector, label
        except Exception as e:
            print(f"Error getting item at index {idx}: {e}")
            # Return zero tensors as fallback
            return (torch.zeros((3, 224, 224)),
                    torch.zeros(len(self.vocab)),
                    0)
# Define CNN for image processing
class ImageModel(nn.Module):
    """ImageNet-pretrained ResNet-18 whose classification layer is replaced
    by a 512-dimensional linear projection.

    The forward pass returns ``self.model(x)``. ResNet presumably expects
    ``x`` to be a (N, 3, H, W) float tensor -- confirm against callers.
    """

    def __init__(self):
        super(ImageModel, self).__init__()
        # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
        # ``weights`` enum; IMAGENET1K_V1 is the checkpoint the legacy flag
        # loads. Fall back to the flag on torchvision versions without the enum.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            self.model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.model = models.resnet18(pretrained=True)
        # Replace the backbone's classifier with a freshly initialised 512-d head.
        self.model.fc = nn.Linear(self.model.fc.in_features, 512)

    def forward(self, x):
        return self.model(x)
# Define MLP for text processing
class TextMLP(nn.Module):
    """Map a bag-of-words vector to a 512-dimensional text embedding.

    Architecture: Linear(vocab_size->1024) -> ReLU -> Dropout(0.3)
    -> Linear(1024->512) -> ReLU -> Dropout(0.2) -> Linear(512->512).

    Args:
        vocab_size: Width of the input count vector.
    """

    def __init__(self, vocab_size):
        super().__init__()
        stack = [
            nn.Linear(vocab_size, 1024), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(512, 512),
        ]
        # Keep the attribute name ``layers`` so state_dict keys stay stable.
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        embedding = self.layers(x)
        return embedding
# Late-fusion model combining the image and text encoders.
class CombinedModel(nn.Module):
    """Concatenate image and text embeddings and score them with one linear layer.

    Args:
        vocab_size: Input width of the text branch (bag-of-words length).
        num_classes: Number of output logits.
    """

    def __init__(self, vocab_size, num_classes):
        super().__init__()
        self.image_model = ImageModel()
        self.text_model = TextMLP(vocab_size)
        # Each encoder emits 512 features, so the fused vector is 512 + 512 wide.
        self.fc = nn.Linear(512 + 512, num_classes)

    def forward(self, image, text):
        fused = torch.cat([self.image_model(image), self.text_model(text)], dim=1)
        return self.fc(fused)
# Build the dataset wrapper and report its size.
print("Creating dataset...")
custom_dataset = CustomDataset(dataset)
print(f"Vocabulary size: {len(custom_dataset.vocab)}")
print(f"Number of valid samples: {len(custom_dataset)}")

# Build the fusion model with one output logit per distinct Model label.
# NOTE(review): no training step is visible in this file, so every layer
# except the pretrained ResNet backbone keeps its random initialisation -- confirm.
num_classes = len(custom_dataset.label_encoder.classes_)
model = CombinedModel(vocab_size=len(custom_dataset.vocab), num_classes=num_classes)
def get_recommendations(image):
    """Return up to five ``(image, model_name)`` pairs for the uploaded *image*.

    The image is scored by the global fusion ``model`` with an all-zero text
    vector, because no prompt is available at inference time. The model's
    outputs are per-class logits (LabelEncoder codes), so each of the top
    classes is mapped back to a concrete dataset row: the first row carrying
    that label.

    Args:
        image: PIL image from the Gradio input, or ``None`` when nothing was
            uploaded.

    Returns:
        A list of ``(image, caption)`` tuples for ``gr.Gallery``. It is empty
        when no image was supplied.
    """
    if image is None:
        return []

    model.eval()
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    # Force 3 channels: RGBA/greyscale input would not fit the ResNet stem.
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    image_tensor = transform(image).unsqueeze(0)
    # No prompt at inference time, so the text branch sees an empty bag of words.
    dummy_text = torch.zeros((1, len(custom_dataset.vocab)))

    with torch.no_grad():
        output = model(image_tensor, dummy_text)
    # torch.topk raises if k exceeds the number of classes.
    k = min(5, output.size(1))
    _, class_indices = torch.topk(output, k)

    # Outputs index *classes*, not dataset rows: map each class to the first
    # row whose encoded label equals it.
    first_row_for_class = {}
    for row_idx, label in enumerate(custom_dataset.labels):
        first_row_for_class.setdefault(int(label), row_idx)

    recommendations = []
    for class_idx in class_indices[0].tolist():
        row_idx = first_row_for_class.get(class_idx)
        if row_idx is None:
            continue
        try:
            row = custom_dataset.valid_dataset[row_idx]
            recommendations.append((row['image'], f"{row['Model']}"))
        except Exception as e:
            print(f"Error getting recommendation for index {class_idx}: {e}")
    return recommendations
# Gradio UI: a single PIL image in, a gallery of (image, caption) pairs out.
interface = gr.Interface(
    title="Image Recommendation System",
    description="Upload an image and get similar images with their model names.",
    fn=get_recommendations,
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(label="Recommended Images"),
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    interface.launch()