import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Hub repository that provides both the classifier weights and its
# preprocessing configuration.
MODEL_REPO_ID = "srtangirala/resnet50-exp"


def load_model_from_hub(repo_id: str):
    """
    Load an image-classification model and its processor from the Hugging Face Hub.

    Args:
        repo_id: The repository ID (e.g., 'username/model-name')

    Returns:
        model: The loaded model
        processor: The feature extractor/processor
    """
    model = AutoModelForImageClassification.from_pretrained(repo_id)
    # NOTE(review): recent transformers releases steer users towards
    # AutoImageProcessor instead of AutoFeatureExtractor -- confirm against the
    # pinned transformers version before switching.
    processor = AutoFeatureExtractor.from_pretrained(repo_id)
    return model, processor


def _class_probabilities(image, model, processor):
    """
    Run a single forward pass and return softmax probabilities.

    Shared by `predict` and `predict_image` so both entry points preprocess
    and score images the same way.

    Args:
        image: PIL image to classify
        model: Loaded model
        processor: Feature extractor/processor

    Returns:
        Tensor of probabilities, softmax-normalised over the last dimension
        of the model's logits
    """
    # Normalise to RGB so grayscale, palette, RGBA or CMYK inputs do not reach
    # the processor with an unexpected channel count.
    inputs = processor(images=image.convert("RGB"), return_tensors="pt")

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)

    return outputs.logits.softmax(-1)


def predict(image_path: str, model, processor):
    """
    Make a prediction for an image stored on disk.

    Args:
        image_path: Path to input image
        model: Loaded model
        processor: Feature extractor/processor

    Returns:
        prediction: Softmax probabilities computed from the model's logits
    """
    # Image.open is lazy and keeps the file open; convert() forces the pixel
    # data to be read and returns an independent image, so the file handle can
    # be released as soon as the `with` block exits.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    return _class_probabilities(image, model, processor)


def predict_image(image):
    """
    Gradio interface function for prediction.

    Uses the module-level `model` and `processor`.

    Args:
        image: Image uploaded through the Gradio interface (a PIL image or an
            array accepted by `Image.fromarray`); None when nothing was uploaded

    Returns:
        str: Prediction result with confidence score, or a short message when
            no image was provided
    """
    # Submitting the form without an image hands the callback None; answer
    # with a readable message instead of failing inside Image.fromarray.
    if image is None:
        return "No image provided. Please upload an image."

    # Convert from numpy array to PIL Image
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    probabilities = _class_probabilities(image, model, processor)[0]

    # Highest-scoring class and its probability.
    top_score, top_index = probabilities.max(dim=-1)
    top_pred_idx = int(top_index)
    confidence = top_score.item()

    # Fall back to a generic label when the config has no usable mapping for
    # this index.
    id2label = getattr(model.config, "id2label", None) or {}
    label = id2label.get(top_pred_idx, f"Class {top_pred_idx}")

    return f"{label} (Confidence: {confidence:.2%})"


# Load model at startup
model, processor = load_model_from_hub(MODEL_REPO_ID)

# Create Gradio interface
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(),
    outputs=gr.Text(),
    title="Image Classification",
    description="Upload an image to classify it!",
    examples=[
        # You can add example images here
        # ["path/to/example1.jpg"],
        # ["path/to/example2.jpg"]
    ],
)

if __name__ == "__main__":
    iface.launch()