Spaces:
Build error
import os

# ── Fix Matplotlib cache permission errors ──
# This must run BEFORE the third-party imports below: gradio may import
# matplotlib while it is being imported, and matplotlib resolves its
# config/cache directory when it is first imported, so setting the variable
# afterwards has no effect.
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"

import gradio as gr  # noqa: E402  (must follow the MPLCONFIGDIR assignment)
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
from torchvision import models, transforms  # noqa: E402
from PIL import Image  # noqa: E402
# ── Constants and Paths ──
# Resolve the checkpoint next to this script instead of the current working
# directory, so the app still starts when launched from another folder.
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "mobilenetv2.pth")
# Output index i of the classifier is reported as CLASS_NAMES[i].
# NOTE(review): this order must match the label order used at training time — confirm.
CLASS_NAMES = ["undercooked", "raw", "cooked"]
# Target size passed to transforms.Resize for every input image.
IMAGE_SIZE = (224, 224)
# ── Load model ──
def load_model():
    """Build the fine-tuned MobileNetV2 classifier and load its weights.

    The network is created without pretrained weights. Its final linear layer
    (``classifier[1]``) is replaced by one with ``len(CLASS_NAMES)`` outputs,
    and then the state dict at ``MODEL_PATH`` is loaded onto the CPU.

    Returns:
        The model, switched to evaluation mode.
    """
    net = models.mobilenet_v2(weights=None)
    head_in_features = net.last_channel
    net.classifier[1] = nn.Linear(head_in_features, len(CLASS_NAMES))

    # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only host.
    # NOTE(review): the checkpoint is unpickled by torch.load; this is only safe
    # because it is a trusted file shipped with the app.
    weights = torch.load(MODEL_PATH, map_location="cpu")
    net.load_state_dict(weights)

    net.eval()
    return net


# Built once at import time and shared by every request.
model = load_model()
# ── Image transform ──
# Preprocessing applied to each uploaded image before inference:
# resize to IMAGE_SIZE, then convert the PIL image to a tensor.
# NOTE(review): there is no transforms.Normalize step; if the checkpoint was
# trained with ImageNet mean/std normalization, inputs here won't match —
# confirm against the training pipeline.
_preprocess_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
def classify(image, progress=gr.Progress(track_tqdm=True)):
    """Predict the doneness class of a food image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when Classify is clicked before anything has been uploaded.
        progress: Gradio progress tracker (``track_tqdm=True``). It is not used
            directly, but is kept so the Gradio signature stays unchanged.

    Returns:
        A dict mapping every name in ``CLASS_NAMES`` to its softmax probability,
        suitable for ``gr.Label``. Returns ``{"Error": 0.0}`` when no image is
        given or inference fails.
    """
    if image is None:
        return {"Error": 0.0}
    try:
        # Force 3 channels (e.g. RGBA/greyscale uploads), then add a batch dim.
        batch = transform(image.convert("RGB")).unsqueeze(0)
        with torch.no_grad():
            logits = model(batch)
        # Report real confidences instead of a hard-coded 1.0 for the argmax
        # class. Softmax is monotonic, so the top label is unchanged.
        probs = torch.softmax(logits, dim=1)[0].tolist()
        return {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
    except Exception as e:
        # UI boundary: keep the app responsive, but keep the full traceback in the logs.
        import traceback

        print("Error during prediction:", e)
        traceback.print_exc()
        return {"Error": 0.0}
# ── Gradio Layout ──
# The CSS caps the width of the column tagged "main-col" and centres it.
with gr.Blocks(css="#main-col {max-width: 640px; margin: auto;}") as demo:
    with gr.Column(elem_id="main-col"):
        gr.Markdown("## 🍳 MobileNetV2 Food Doneness Classifier")
        gr.Markdown("Upload an image of food to determine if it's **undercooked**, **raw**, or **cooked**.")
        with gr.Row():
            # type="pil" hands classify() a PIL.Image (or None if nothing uploaded).
            image_input = gr.Image(type="pil", label="Upload Image")
            run_button = gr.Button("Classify", variant="primary")
        result_output = gr.Label(label="Prediction")
        gr.HTML("<br><small>Custom model trained with MobileNetV2</small>")
    # Event wiring is declared inside the Blocks context.
    run_button.click(fn=classify, inputs=image_input, outputs=result_output)
if __name__ == "__main__":
    # Bind on all interfaces at port 7860.
    # NOTE(review): share=True requests a public share link when run locally;
    # confirm that this exposure is intended (hosted Spaces may ignore it).
    launch_options = dict(server_name="0.0.0.0", server_port=7860, share=True)
    demo.launch(**launch_options)