# CNN_MLP_2 / app.py
import torch
import torch.nn as nn
from torchvision import transforms, models
import gradio as gr
from PIL import Image

# Model architecture (same as the one used for training)
class ModelRecommender(nn.Module):
    def __init__(self, num_models, text_embedding_dim=768):
        super(ModelRecommender, self).__init__()
        # CNN for image processing: ResNet-18 backbone with the classifier
        # head replaced by a 512 -> 256 projection
        self.cnn = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.cnn.fc = nn.Linear(512, 256)
        # MLP for text processing (768-d embedding -> 256-d features)
        self.text_mlp = nn.Sequential(
            nn.Linear(text_embedding_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU()
        )
        # Combined classification head over the concatenated 512-d features
        self.combined = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_models)
        )

    def forward(self, image, text_features):
        # Process image
        img_features = self.cnn(image)
        # Process text
        text_features = self.text_mlp(text_features)
        # Combine features
        combined = torch.cat((img_features, text_features), dim=1)
        # Final prediction
        output = self.combined(combined)
        return output
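
# Usage sketch (shapes inferred from the preprocessing and the dummy text
# features used further below; not part of the original app):
#
#     model = ModelRecommender(num_models=10)
#     logits = model(torch.randn(1, 3, 224, 224), torch.randn(1, 768))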

# Load model weights and dataset info
def load_model():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load dataset info (expects a dict containing the list of model names)
    dataset_info = torch.load('dataset_info.pth', map_location=device)
    model_names = dataset_info['model_names']

    # Initialize model
    model = ModelRecommender(len(model_names))

    # Load model weights
    checkpoint = torch.load('sd_recommender_model.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    return model, model_names, device
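
# Assumed checkpoint layout (a sketch of how the training script likely
# produced these files; the keys match the ones read in load_model above):
#
#     torch.save({'model_names': model_names}, 'dataset_info.pth')
#     torch.save({'model_state_dict': model.state_dict()}, 'sd_recommender_model.pth')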

# Inference function
def predict_image(image):
    # Lazily load the model on the first call and cache it on the function object
    if not hasattr(predict_image, "model"):
        predict_image.model, predict_image.model_names, predict_image.device = load_model()

    # Preprocess image (force RGB so grayscale/RGBA uploads do not break the transform)
    image = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    image_tensor = transform(image).unsqueeze(0).to(predict_image.device)

    # The text branch is unused at inference time, so feed zero embeddings
    dummy_text_features = torch.zeros(1, 768).to(predict_image.device)

    # Get top-5 predictions (or fewer if the checkpoint has fewer classes)
    with torch.no_grad():
        outputs = predict_image.model(image_tensor, dummy_text_features)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        top5_prob, top5_indices = torch.topk(probs, min(5, probs.size(1)))

    # Format results
    results = []
    for prob, idx in zip(top5_prob[0], top5_indices[0]):
        model_name = predict_image.model_names[idx.item()]
        confidence = f"{prob.item():.2%}"
        results.append(f"Model: {model_name}\nConfidence: {confidence}")

    return "\n\n".join(results)

# Gradio Interface
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Model Recommendations"),
    title="Stable Diffusion Model Recommender",
    description="Upload an AI-generated image to get model recommendations",
    examples=[["example1.jpg"], ["example2.jpg"]]  # Add example images here if available
)
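
# Optional local smoke test (a sketch; assumes the checkpoint files above and a
# sample image, here the hypothetical "sample.jpg", sit next to this script):
#
#     from PIL import Image
#     print(predict_image(Image.open("sample.jpg")))
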
demo.launch()