|
import torch |
|
import torch.nn as nn |
|
from torchvision import transforms, models |
|
import gradio as gr |
|
from PIL import Image |
|
|
|
|
|
class ModelRecommender(nn.Module):
    """Score candidate Stable Diffusion models from an image plus text features.

    An image branch (ResNet-18 with its classifier head replaced by a 256-d
    projection) and a text branch (two-layer MLP down to 256-d) are
    concatenated and passed through a small classifier that emits one logit
    per candidate model.

    Args:
        num_models: Number of output classes (candidate models).
        text_embedding_dim: Width of the incoming text feature vector.
        pretrained: Initialise the ResNet-18 backbone with torchvision's
            ImageNet weights (downloaded on first use). Pass ``False`` when a
            trained checkpoint will be loaded over the whole model anyway.
    """

    def __init__(self, num_models, text_embedding_dim=768, *, pretrained=True):
        super().__init__()
        feature_dim = 256  # width each branch is projected to before fusion

        # Image branch: swap ResNet-18's classifier head for a projection.
        # Read in_features from the existing head instead of hard-coding it.
        self.cnn = models.resnet18(pretrained=pretrained)
        self.cnn.fc = nn.Linear(self.cnn.fc.in_features, feature_dim)

        # Text branch: MLP from the text embedding down to feature_dim.
        self.text_mlp = nn.Sequential(
            nn.Linear(text_embedding_dim, 512),
            nn.ReLU(),
            nn.Linear(512, feature_dim),
            nn.ReLU(),
        )

        # Fusion head over the concatenated image + text features.
        self.combined = nn.Sequential(
            nn.Linear(2 * feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_models),
        )

    def forward(self, image, text_features):
        """Return one unnormalised score per candidate model.

        Args:
            image: Image batch for the ResNet-18 branch; presumably
                (batch, 3, H, W) normalised tensors — confirm with callers.
            text_features: Expected (batch, text_embedding_dim) tensor.

        Returns:
            (batch, num_models) tensor of logits.
        """
        img_features = self.cnn(image)                # (batch, 256)
        txt_features = self.text_mlp(text_features)   # (batch, 256)
        fused = torch.cat((img_features, txt_features), dim=1)
        return self.combined(fused)


def load_model():
    """Load the trained recommender and the class-name list it predicts over.

    Reads ``dataset_info.pth`` (must contain ``'model_names'``) and
    ``sd_recommender_model.pth`` (must contain ``'model_state_dict'``) from
    the current working directory. Both files go through ``torch.load``,
    i.e. pickle, so only load trusted files.

    Returns:
        Tuple ``(model, model_names, device)``: the model in eval mode on
        ``device``, the class names indexed by output logit, and the
        ``torch.device`` in use (CUDA if available, else CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Metadata only; map to CPU so a file saved on a GPU box loads anywhere.
    dataset_info = torch.load('dataset_info.pth', map_location="cpu")
    model_names = dataset_info['model_names']

    # The strict load_state_dict below overwrites every parameter and buffer,
    # so downloading ImageNet weights for the backbone would be wasted work.
    model = ModelRecommender(len(model_names), pretrained=False)
    checkpoint = torch.load('sd_recommender_model.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.to(device)
    model.eval()  # disables the Dropout in the fusion head
    return model, model_names, device
|
|
|
|
|
# Preprocessing applied to every uploaded image: fixed 224x224 resize, then
# per-channel normalisation with the ImageNet mean/std. Built once, reused.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Width of the (zero) text feature vector fed to the model's text branch.
_TEXT_EMBEDDING_DIM = 768
# Maximum number of recommendations to report.
_TOP_K = 5


def predict_image(image):
    """Return ranked model recommendations for an uploaded image as text.

    Args:
        image: A PIL image (as delivered by ``gr.Image(type="pil")``), or
            ``None`` when the user submits without uploading anything.

    Returns:
        Up to ``_TOP_K`` entries of the form ``"Model: <name>\\nConfidence:
        <pct>"``, highest probability first, separated by blank lines; or a
        short prompt asking for an image when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image."

    # Load lazily on first call and cache on the function object so later
    # calls reuse the same model.
    if not hasattr(predict_image, "model"):
        predict_image.model, predict_image.model_names, predict_image.device = load_model()
    model = predict_image.model
    model_names = predict_image.model_names
    device = predict_image.device

    # Force 3 channels: RGBA, greyscale or palette images would otherwise
    # fail the 3-channel Normalize.
    image_tensor = _PREPROCESS(image.convert("RGB")).unsqueeze(0).to(device)
    # No text is available at inference time, so the text branch gets zeros.
    dummy_text_features = torch.zeros(1, _TEXT_EMBEDDING_DIM, device=device)

    with torch.no_grad():
        logits = model(image_tensor, dummy_text_features)
        probs = torch.nn.functional.softmax(logits, dim=1)
        # topk raises if k exceeds the number of classes.
        k = min(_TOP_K, probs.size(1))
        top_probs, top_indices = torch.topk(probs, k)

    results = [
        f"Model: {model_names[idx.item()]}\nConfidence: {prob.item():.2%}"
        for prob, idx in zip(top_probs[0], top_indices[0])
    ]
    return "\n\n".join(results)
|
|
|
|
|
# Gradio UI: one image in, a text box of ranked model recommendations out.
# NOTE(review): example1.jpg / example2.jpg are assumed to sit next to this
# script — confirm they ship with the app.
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Model Recommendations"),
    title="Stable Diffusion Model Recommender",
    description="Upload an AI-generated image to get model recommendations",
    examples=[["example1.jpg"], ["example2.jpg"]],
)

# Start the server only when run as a script. Importing this module (e.g. for
# tests or tooling that only needs the `demo` object) then has no side effect.
if __name__ == "__main__":
    demo.launch()