import torch
from torchvision import transforms
from PIL import Image
import gradio as gr
from ResNet_for_CC import CC_model

# Initialize the model on the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CC_model()

# Load the pre-trained weights, adjusting for DataParallel if necessary.
# A checkpoint saved from an nn.DataParallel wrapper prefixes every key with
# 'module.'. Only that *leading* prefix is stripped: str.replace would also
# mangle keys that merely contain the substring (e.g. 'submodule.weight').
model_path = 'CC_net.pt'
checkpoint = torch.load(model_path, map_location=device)
_DP_PREFIX = 'module.'
checkpoint = {
    (k[len(_DP_PREFIX):] if k.startswith(_DP_PREFIX) else k): v
    for k, v in checkpoint.items()
}
model.load_state_dict(checkpoint)
model.eval()
model.to(device)

# Image preprocessing: resize, center-crop to 224x224, convert to a tensor and
# normalise with the ImageNet channel mean/std.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Define class names from category_names_eng.txt.
# NOTE(review): list index must match the label index used at training time.
class_names = [
    'T-Shirt', 'Shirt', 'Knitwear', 'Chiffon', 'Sweater', 'Hoodie',
    'Windbreaker', 'Jacket', 'Downcoat', 'Suit', 'Shawl', 'Dress',
    'Vest', 'Underwear',
]


def predict(image):
    """Classify a single clothing image.

    Args:
        image: Image array as delivered by ``gr.Image`` (numpy by default),
            or ``None`` when the user submits without uploading anything.

    Returns:
        A human-readable string naming the predicted class, or a prompt to
        upload an image when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image."

    # Let PIL infer the mode from the array shape, then force 3-channel RGB so
    # grayscale (H, W) and RGBA (H, W, 4) uploads are handled too.
    img = Image.fromarray(image.astype('uint8')).convert('RGB')
    batch = preprocess(img).unsqueeze(0).to(device)

    # Generate predictions without tracking gradients.
    with torch.no_grad():
        # The model returns a pair; only the second element is used as the
        # class scores (assumed shape (1, num_classes) — TODO confirm).
        _, output_mean = model(batch)

    # Pick the highest-scoring class for the single image in the batch.
    _, predicted = torch.max(output_mean, 1)
    predicted_class = class_names[predicted.item()]
    return f"Predicted class: {predicted_class}"


# Example images from Hugging Face
examples = [
    ["example_image(1).JPG"],
    ["example_image(2).jpg"],
    ["example_image(3).jpg"],
    ["example_image(4).webp"],
    ["example_image(5).webp"],
    ["example_image(6).webp"],
]

# Gradio Interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Upload Clothing Image"),
    outputs=gr.Textbox(label="Prediction"),
    title="Clothing Image Classifier",
    description=(
        "This model classifies clothing images using ResNet50. "
        "Try out different examples below for a quick demonstration!"
    ),
    examples=examples,
)
interface.launch()