import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import timm


class ImprovedMultiOutputModel(nn.Module):
    """Improved multi-output model with EfficientNet backbone."""

    def __init__(self, num_object_classes, num_material_classes, backbone='efficientnet_b0'):
        super().__init__()
        # Use EfficientNet backbone as a feature extractor (num_classes=0 removes its head)
        self.backbone = timm.create_model(backbone, pretrained=True, num_classes=0)
        backbone_out_features = self.backbone.num_features
        # Attention mechanism: a sigmoid gate over the pooled backbone features
        self.attention = nn.Sequential(
            nn.Linear(backbone_out_features, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, backbone_out_features),
            nn.Sigmoid()
        )
        # Classification heads with dropout and batch norm
        self.object_classifier = nn.Sequential(
            nn.Linear(backbone_out_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_object_classes)
        )
        self.material_classifier = nn.Sequential(
            nn.Linear(backbone_out_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_material_classes)
        )

    def forward(self, x):
        # Extract pooled features from the backbone
        features = self.backbone(x)
        # Re-weight features with the attention gate
        attention_weights = self.attention(features)
        features = features * attention_weights
        # One logit vector per attribute
        object_pred = self.object_classifier(features)
        material_pred = self.material_classifier(features)
        return {
            'object_name': object_pred,
            'material': material_pred,
        }
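
# Quick shape check (illustrative only, not run by the app): with dummy class
# counts, a 224x224 batch should yield one logit vector per attribute. Note
# this would download the pretrained efficientnet_b0 weights on first use.
# m = ImprovedMultiOutputModel(num_object_classes=10, num_material_classes=5).eval()
# with torch.no_grad():
#     out = m(torch.randn(2, 3, 224, 224))
# print(out['object_name'].shape, out['material'].shape)  # -> (2, 10), (2, 5)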


def get_val_transforms():
    """Get transforms for validation/inference."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # ImageNet mean/std, matching the pretrained backbone's training statistics
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
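
# Usage note (illustrative; 'example.jpg' is a hypothetical path): applied to
# a PIL image, this pipeline yields a float tensor of shape (3, 224, 224).
# tensor = get_val_transforms()(Image.open('example.jpg').convert('RGB'))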


def load_model(model_path):
    """Load a checkpoint and rebuild the model from its stored label mappings."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # weights_only=False is needed to unpickle the label_mappings dict
    # (PyTorch >= 2.6 defaults torch.load to weights_only=True)
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    label_mappings = checkpoint['label_mappings']
    num_object_classes = len(label_mappings['object_name'])
    num_material_classes = len(label_mappings['material'])
    backbone = 'efficientnet_b0'
    model = ImprovedMultiOutputModel(num_object_classes, num_material_classes, backbone)
    # strict=False tolerates minor key mismatches between checkpoint and model
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model.to(device)
    model.eval()
    return model, label_mappings
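
# Checkpoint layout assumed above (a sketch; the actual training script may
# differ). Label mappings are taken to be name -> index dicts saved next to
# the weights, which is what the reverse lookup in predict() relies on.
# torch.save({
#     'model_state_dict': model.state_dict(),
#     'label_mappings': {
#         'object_name': {'amphora': 0, 'coin': 1},  # hypothetical labels
#         'material': {'ceramic': 0, 'bronze': 1},   # hypothetical labels
#     },
# }, 'modelv1.pth')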


# Load both model versions once at startup
models = {}
models['modelv1.pth'], label_mappings_v1 = load_model('modelv1.pth')
models['modelv2.pth'], label_mappings_v2 = load_model('modelv2.pth')
# Assume the label mappings are the same for both checkpoints; use v1
label_mappings = label_mappings_v1
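
# Optional sanity check (a sketch): the app assumes both checkpoints share one
# label space; uncomment to fail fast at startup if that ever stops holding.
# assert label_mappings_v1 == label_mappings_v2, "checkpoints disagree on label mappings"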


def predict(image, model_choice):
    if image is None:
        return "Please upload an image."
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = models[model_choice]
    transform = get_val_transforms()
    # Convert to RGB so RGBA or grayscale uploads match the 3-channel normalization
    image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(image_tensor)
        pred_obj = torch.argmax(outputs['object_name'], dim=1).item()
        pred_mat = torch.argmax(outputs['material'], dim=1).item()
    # Map predicted indices back to names (mappings are name -> index)
    obj_name = [k for k, v in label_mappings['object_name'].items() if v == pred_obj][0]
    mat_name = [k for k, v in label_mappings['material'].items() if v == pred_mat][0]
    return f"Predicted Object: {obj_name}\nPredicted Material: {mat_name}"


# Create the Gradio interface using Blocks
with gr.Blocks(title="Artifact Classification Model") as demo:
    gr.Markdown("# Artifact Classification Model")
    gr.Markdown("Upload an image to classify the object name and material.")
    model_selector = gr.Dropdown(choices=['modelv1.pth', 'modelv2.pth'], label="Select Model", value='modelv1.pth')
    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload an Image")
        output_text = gr.Textbox(label="Predictions")
    predict_btn = gr.Button("Predict")
    predict_btn.click(fn=predict, inputs=[input_image, model_selector], outputs=output_text)

if __name__ == "__main__":
    demo.launch()