Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torchvision.transforms as transforms | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from ResNet_for_CC import CC_model # Ensure correct model import | |
# Run on the default CUDA device when PyTorch can see a GPU; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Trained CC_model weights, loaded below as a state_dict; update if the file moves.
model_path = "CC_net.pt"

# Build the 14-class network and restore its trained weights.
# map_location=device lets a checkpoint saved on GPU load on a CPU-only host.
# The model is then moved to `device` and switched to evaluation mode.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
model = CC_model(num_classes1=14)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()
# Clothing1M category names. classify_image looks labels up by the model's
# output index, so position i names output logit i.
# NOTE(review): presumably this order matches the label ids used in training — verify.
class_labels = [
    "T-Shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Downcoat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
]
# Evaluation-time preprocessing applied to every input image:
#   1. resize so the shorter side is 256 px,
#   2. take the central 224x224 crop,
#   3. convert the PIL image to a float tensor scaled to [0, 1],
#   4. normalize each channel with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Inference function with confidence scores
def classify_image(image):
    """Classify a clothing image and report the top class with its confidence.

    Args:
        image: The PIL image delivered by the ``gr.Image(type="pil")`` input,
            or ``None`` when the user submits without uploading anything.

    Returns:
        A two-line string, "Predicted Class: <label>" followed by
        "Confidence: <pct>%". Here <pct> is the softmax probability of the top
        class, as a percentage with two decimals. If ``image`` is ``None``, a
        short prompt asking for an upload is returned instead.
    """
    # Submitting with an empty input gives None; without this guard the
    # transform pipeline fails with an opaque error.
    if image is None:
        return "Please upload an image to classify."

    # Normalize expects exactly 3 channels. Uploads and the .png examples can
    # be RGBA (alpha channel), grayscale ("L") or palette ("P") images.
    image = image.convert("RGB")

    # Add a leading batch dimension of 1 and move to the model's device.
    image_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
    probabilities = F.softmax(output, dim=1).cpu().numpy()[0]

    predicted_idx = int(probabilities.argmax())
    predicted_label = class_labels[predicted_idx]
    # Softmax probability of the top class: a point estimate, not an interval.
    # Converted to a Python float so formatting is independent of numpy scalars.
    confidence_pct = float(probabilities[predicted_idx]) * 100

    return f"Predicted Class: {predicted_label}\nConfidence: {confidence_pct:.2f}%"
# Sample inputs listed under the Gradio interface: img1.png through img5.png.
# These are relative paths, so upload the files alongside this app in the
# Hugging Face Space.
example_images = [f"img{number}.png" for number in range(1, 6)]
# Gradio UI: one PIL image in, the prediction text produced by classify_image out.
# The app reports a single softmax probability (a confidence score), not a
# statistical confidence interval, so the title says "Confidence Score".
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),  # classify_image receives a PIL image (or None)
    outputs=gr.Textbox(label="Prediction and Confidence"),
    title="Clothing1M Image Classifier with Confidence Score",
    description="Upload an image or select from examples to classify it and view the confidence percentage.",
    examples=example_images,
    # Do not pre-compute example outputs at startup; run them when clicked.
    cache_examples=False,
)
def _main():
    """Start the Gradio web server that serves the classifier interface."""
    interface.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    _main()