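# Gradio inference app: downloads a ResNet checkpoint from S3 and serves a
# cat-vs-dog image classifier through a simple web UI.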
import gradio as gr
import rootutils
import torch
import torch.nn.functional as F
from loguru import logger
from pathlib import Path
from PIL import Image
from torchvision import transforms

# Setup the project root first so that the `src` package resolves regardless
# of the working directory
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from src.models.catdog_model_resnet import ResnetClassifier
from src.utils.aws_s3_services import S3Handler
from src.utils.logging_utils import setup_logger

# Configure the application logger
setup_logger(Path("./logs") / "gradio_app.log")


class ImageClassifier:
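    """Loads the trained ResNet checkpoint and runs single-image inference."""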

    def __init__(self, cfg):
        self.cfg = cfg
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.classes = cfg.labels

        # Download and load model from S3
        logger.info("Downloading model from S3...")
        s3_handler = S3Handler(bucket_name="deep-bucket-s3")
        s3_handler.download_folder("checkpoints", "checkpoints")

        logger.info("Loading model checkpoint...")
        self.model = ResnetClassifier.load_from_checkpoint(
            checkpoint_path=cfg.ckpt_path
        )
        self.model = self.model.to(self.device)
        self.model.eval()

        # Image transform (ImageNet normalization statistics)
        self.transform = transforms.Compose(
            [
                transforms.Resize((cfg.data.image_size, cfg.data.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def predict(self, image):
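        """Classify a PIL image; returns (label, confidence) or (None, None)."""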
        if image is None:
            return None, None

        # Preprocess the image
        logger.info("Processing input image...")
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Inference
        with torch.no_grad():
            output = self.model(img_tensor)
            probabilities = F.softmax(output, dim=1)
            predicted_class_idx = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0][predicted_class_idx].item()

        predicted_label = self.classes[predicted_class_idx]
        logger.info(f"Prediction: {predicted_label} (Confidence: {confidence:.2f})")
        return predicted_label, confidence


def create_gradio_app(cfg):
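    """Build the Gradio Blocks UI around an ImageClassifier instance."""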
    classifier = ImageClassifier(cfg)

    def classify_image(image):
        """Gradio interface function."""
        predicted_label, confidence = classifier.predict(image)
        if predicted_label is not None:
            return f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
        return "No image provided."
    # Create Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Cat vs Dog Classifier
            Upload an image of a cat or a dog to classify it with confidence.
            """
        )
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image", type="pil", image_mode="RGB"
                )
                predict_button = gr.Button("Classify")
            with gr.Column():
                output_text = gr.Textbox(label="Prediction")

        # Define interaction
        predict_button.click(
            fn=classify_image, inputs=[input_image], outputs=[output_text]
        )

    return demo


# Hydra entry point for launching the Gradio app
if __name__ == "__main__":
    import hydra
    from omegaconf import DictConfig

    # NOTE: config_path/config_name below are assumptions; point them at the
    # project's actual Hydra config directory and inference config.
    @hydra.main(version_base="1.3", config_path="configs", config_name="infer")
    def main(cfg: DictConfig) -> None:
        logger.info("Launching Gradio App...")
        demo = create_gradio_app(cfg)
        demo.launch(share=True, server_name="0.0.0.0", server_port=7860)

    main()
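# Usage sketch (assumed): on a Hugging Face Space this file is started
# automatically as app.py; locally it can be run with `python app.py`, and
# Hydra overrides such as `python app.py ckpt_path=checkpoints/model.ckpt`
# apply if the chosen config exposes that key.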