"""Gradio app that classifies pants shape with a fine-tuned ViT model.

At import time the script loads a fine-tuned ``ViTForImageClassification``
checkpoint from the working directory and builds a Gradio interface around
``predict``. The interface is launched only when the file is run as a script.
"""

import logging
import os

import gradio as gr
import torch
from PIL import Image
from transformers import ViTConfig, ViTFeatureExtractor, ViTForImageClassification

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Define the class labels in the correct order as used during training.
# The position of each label is its class index in the model's output logits.
labels = ['Leggings', 'Jogger', 'Palazzo', 'Cargo', 'Dresspants', 'Chinos']
logging.info(f"Labels: {labels}")

# Define the path to the uploaded model file
model_path = "best_fine_tuned_vit_Leggings_Jogger_Palazzo_Cargo_Dresspants_Chinos_93.90243902439025_2024-08-26.pth"
logging.info(f"Looking for model file: {model_path}")

# Fail fast with a clear message when the checkpoint is missing, before any
# model/config setup is attempted.
if os.path.exists(model_path):
    logging.info(f"Model file found: {model_path}")
else:
    logging.error(f"Model file not found: {model_path}")
    raise FileNotFoundError(f"Model file not found: {model_path}")

# Create label mappings consistent with training.
# Ids are integers: the config is assigned these dicts directly (after
# construction), so nothing converts string keys for us, and an integer lookup
# such as ``config.id2label[predicted_index]`` would miss with string keys.
id2label = {i: label for i, label in enumerate(labels)}
label2id = {label: i for i, label in enumerate(labels)}

# Create a configuration for the model
config = ViTConfig.from_pretrained("google/vit-base-patch16-224-in21k")
config.num_labels = len(labels)
config.id2label = id2label
config.label2id = label2id

# Initialize the model with the configuration (weights are filled in below)
model = ViTForImageClassification(config)

try:
    # Load the state dict of the fine-tuned model.
    # SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a trusted source; consider passing
    # weights_only=True on PyTorch versions that support it.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    logging.info("Fine-tuned model loaded successfully")
except Exception as e:
    # Top-level boundary: record the failure, then let it propagate so the app
    # does not start with an uninitialised model.
    logging.error(f"Error loading model: {str(e)}")
    raise

model.eval()
logging.info("Model set to evaluation mode")

# Load feature extractor.
# NOTE(review): ViTFeatureExtractor is deprecated in recent transformers
# releases in favour of ViTImageProcessor - confirm against the pinned version
# before upgrading.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
logging.info("Feature extractor loaded")


# Define the prediction function
def predict(image: Image.Image) -> dict:
    """Classify a single image and return per-label probabilities.

    Args:
        image: The PIL image supplied by the Gradio ``Image`` input.

    Returns:
        A dict mapping every entry of ``labels`` to its softmax probability
        as a Python float.

    Raises:
        gr.Error: If no image was provided (``image`` is ``None``).
    """
    logging.info("Starting prediction")

    # Submitting without an image yields None; report it to the user instead of
    # failing with an AttributeError on ``image.size``.
    if image is None:
        logging.warning("No image provided")
        raise gr.Error("Please upload an image before requesting a prediction.")

    logging.info(f"Input image shape: {image.size}")

    # NOTE(review): the checkpoint is presumably trained on 3-channel input;
    # normalise the mode in case predict is called with a grayscale or RGBA
    # image rather than the RGB image the Gradio input is expected to deliver.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess the image
    logging.info("Preprocessing image")
    inputs = feature_extractor(images=image, return_tensors="pt")
    logging.info(f"Preprocessed input shape: {inputs['pixel_values'].shape}")

    logging.info("Running inference")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        # logits[0] selects the single image in the batch; softmax runs over
        # its class dimension.
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    logging.info(f"Raw logits: {logits}")
    logging.info(f"Probabilities: {probabilities}")

    # Prepare the output dictionary (label order matches the logit order)
    result = {label: float(prob) for label, prob in zip(labels, probabilities)}
    logging.info(f"Prediction result: {result}")
    return result


# Set up the Gradio Interface
logging.info("Setting up Gradio interface")
gradio_app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=len(labels)),
    title="Pants Shape Classifier",
)

# Launch the app
if __name__ == "__main__":
    logging.info("Launching the app")
    gradio_app.launch()