Spaces:
Build error
Build error
| from fasthtml.common import * | |
| from fastai.vision.all import * | |
| import os | |
| import time | |
| from pathlib import Path | |
| import urllib.request | |
| from io import BytesIO | |
# Ensure the upload directory exists before any request handler writes into it.
Path('uploads').mkdir(parents=True, exist_ok=True)
# Function to load model - with fallback for testing
def load_model(model_path='levit.pkl'):
    """Load the exported fastai learner, falling back to demo mode on failure.

    Args:
        model_path: Path to the exported learner file. Defaults to
            ``'levit.pkl'``, the file this app was originally hard-wired to.

    Returns:
        A ``(learn, labels)`` tuple. On success ``learn`` is the loaded learner
        and ``labels`` is ``learn.dls.vocab``. If the file is missing or loading
        raises, ``learn`` is ``None`` and ``labels`` is a placeholder list, so
        the rest of the app can still run in demo mode.
    """
    # Built per call so callers never share (and mutate) one list object.
    fallback_labels = ['class1', 'class2', 'class3']

    if not os.path.exists(model_path):
        print("Model not found. This is just for testing purposes.")
        # In a real deployment, you'd want to handle this more gracefully
        return None, fallback_labels

    try:
        # NOTE: load_learner unpickles the file, which can execute arbitrary
        # code. Only point this at model files you trust.
        learn = load_learner(model_path)
        labels = learn.dls.vocab
    except Exception as e:
        # Deliberate best-effort: a broken model file degrades to demo mode.
        print(f"Error loading model: {e}")
        return None, fallback_labels

    print(f"Model loaded successfully with labels: {labels}")
    return learn, labels
# Load the model once at startup. `learn` is None when load_model fell back to
# demo mode; `labels` is then the placeholder class list. Both names are read by
# predict() and by the page handlers below.
learn, labels = load_model()
# Create a FastHTML app. `rt` is presumably FastHTML's route-registration helper
# (see fast_app docs); confirm how handlers below are meant to be registered.
app, rt = fast_app()
# Define the prediction function
def predict(img_bytes):
    """Classify raw image bytes and return a label -> probability mapping.

    Demo mode applies when the module-level ``learn`` is ``None``. It returns a
    random score for every entry of the module-level ``labels``, normalised to
    sum to 1. Otherwise the bytes are decoded with ``PILImage.create``, resized
    to 512x512 and passed to ``learn.predict``.

    Args:
        img_bytes: Encoded image file contents (ignored in demo mode).

    Returns:
        A dict mapping each label to a float probability, ordered from most
        to least likely. On any exception the error is printed and
        ``{"Error": 1.0}`` is returned, so callers can always render a result.
    """
    try:
        if learn is None:
            # No model loaded: produce mock predictions for testing the UI.
            import random
            mock_scores = {label: random.random() for label in labels}
            total = sum(mock_scores.values())
            results = {label: score / total for label, score in mock_scores.items()}
        else:
            img = PILImage.create(BytesIO(img_bytes))
            img = img.resize((512, 512))
            _pred, _pred_idx, probs = learn.predict(img)
            # probs is assumed to be ordered like `labels` (learn.dls.vocab),
            # the same alignment the original index-based lookup relied on.
            results = {label: float(p) for label, p in zip(labels, probs)}
        # Same ordering for both modes: most likely label first.
        return dict(sorted(results.items(), key=lambda kv: kv[1], reverse=True))
    except Exception as e:
        print(f"Prediction error: {e}")
        return {"Error": 1.0}
# Main page route
@rt("/")
def get():
    """Render the main page: upload form, loading indicator, results area,
    example link and model information.

    The form posts the chosen file to ``/predict`` via htmx and swaps the
    returned fragment into ``#result``. The example link does the same via
    ``/predict_example``.
    """
    # Create a form for image upload.
    # hx_indicator belongs on the form because the form is the element issuing
    # the request. On the child <input> htmx never consults it, and the
    # loading indicator would not appear for uploads.
    upload_form = Form(
        Div(
            H1("FastAI Image Classifier"),
            P("Upload an image to classify it using a pre-trained model."),
            cls="instructions"
        ),
        Div(
            Input(type="file", name="image", accept="image/*", required=True),
            Button("Classify", type="submit"),
            cls="upload-controls"
        ),
        hx_post="/predict",
        hx_target="#result",
        hx_swap="innerHTML",
        hx_encoding="multipart/form-data",
        hx_indicator="#loading",
        id="upload-form"
    )

    # Loading indicator, hidden by CSS until an htmx request is in flight.
    loading = Div(
        P("Processing your image..."),
        id="loading",
        cls="htmx-indicator"
    )

    # Container the prediction fragments are swapped into.
    result_container = Div(id="result", cls="result-container")

    # Example section
    examples = Div(
        H2("Or try an example:"),
        A("Example Image", href="#",
          hx_get="/predict_example",
          hx_target="#result",
          hx_indicator="#loading"),
        cls="examples-section"
    )

    # CSS styles
    css = """
    :root {
        --primary-color: #3498db;
        --secondary-color: #2c3e50;
        --background-color: #f9f9f9;
        --error-color: #e74c3c;
        --shadow-color: rgba(0, 0, 0, 0.1);
        --border-color: #ddd;
    }
    body {
        font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif;
        line-height: 1.6;
        color: #333;
        max-width: 800px;
        margin: 0 auto;
        padding: 20px;
        background-color: #fff;
    }
    h1 {
        color: var(--secondary-color);
        margin-bottom: 1rem;
        font-weight: 600;
    }
    h2 {
        color: var(--primary-color);
        margin-top: 1.5rem;
        font-weight: 500;
    }
    .instructions {
        margin-bottom: 20px;
    }
    .upload-controls {
        display: flex;
        gap: 10px;
        margin-bottom: 30px;
        align-items: center;
        flex-wrap: wrap;
    }
    button {
        background-color: var(--primary-color);
        color: white;
        border: none;
        padding: 10px 15px;
        border-radius: 4px;
        cursor: pointer;
        transition: background-color 0.3s;
        font-weight: 500;
    }
    button:hover {
        background-color: #2980b9;
    }
    input[type="file"] {
        padding: 10px;
        border: 1px solid var(--border-color);
        border-radius: 4px;
        flex-grow: 1;
    }
    #upload-form {
        margin-bottom: 40px;
        padding: 20px;
        border-radius: 8px;
        background-color: var(--background-color);
        box-shadow: 0 2px 10px var(--shadow-color);
    }
    .result-container {
        margin-top: 20px;
    }
    .prediction-results {
        margin-top: 20px;
        padding: 20px;
        border: 1px solid var(--border-color);
        border-radius: 8px;
        background-color: var(--background-color);
        box-shadow: 0 2px 8px var(--shadow-color);
    }
    .result-image {
        max-width: 100%;
        height: auto;
        border-radius: 8px;
        box-shadow: 0 2px 5px var(--shadow-color);
        margin-bottom: 20px;
        display: block;
    }
    .prediction-list {
        margin-top: 15px;
    }
    .prediction-item {
        padding: 12px 15px;
        margin-bottom: 10px;
        background-color: white;
        border-radius: 6px;
        box-shadow: 0 1px 3px var(--shadow-color);
    }
    .label-text {
        margin-bottom: 8px;
        font-weight: 500;
        display: flex;
        justify-content: space-between;
    }
    .examples-section {
        margin-top: 30px;
        padding-top: 20px;
        border-top: 1px solid var(--border-color);
    }
    .htmx-indicator {
        display: none;
        padding: 15px;
        background-color: #e8f4fc;
        border-radius: 6px;
        text-align: center;
        margin: 15px 0;
        box-shadow: 0 1px 3px var(--shadow-color);
    }
    /* When hx-indicator points at #loading, htmx puts .htmx-request on the
       indicator element itself, so the compound selector is required too. */
    .htmx-request .htmx-indicator,
    .htmx-request.htmx-indicator {
        display: block;
    }
    .progress-bar {
        height: 10px;
        background-color: #f0f0f0;
        border-radius: 5px;
        margin: 5px 0;
        overflow: hidden;
    }
    .progress-fill {
        height: 100%;
        background-color: var(--primary-color);
        width: 0;
        transition: width 0.5s ease;
    }
    .error-message {
        color: var(--error-color);
        padding: 15px;
        border: 1px solid var(--error-color);
        border-radius: 5px;
        background-color: #fde9e7;
    }
    a {
        color: var(--primary-color);
        text-decoration: none;
        font-weight: 500;
    }
    a:hover {
        text-decoration: underline;
    }
    /* Responsive styling */
    @media (max-width: 600px) {
        .upload-controls {
            flex-direction: column;
            align-items: stretch;
        }
        button {
            width: 100%;
        }
    }
    .model-info {
        font-size: 0.9rem;
        color: #666;
        margin-top: 40px;
        padding-top: 20px;
        border-top: 1px solid var(--border-color);
    }
    """

    # Model information. Labels are stringified because a learner's vocab is
    # not guaranteed to contain only str values, and str.join requires str.
    model_info = Div(
        P(f"Model: {'Model loaded successfully' if learn is not None else 'Demo mode - no model loaded'}"),
        P(f"Classes: {', '.join(str(label) for label in labels)}"),
        cls="model-info"
    )

    return Titled("FastAI Image Classifier",
                  upload_form,
                  loading,
                  result_container,
                  examples,
                  model_info,
                  Style(css))
# Prediction route for uploaded images
@rt("/predict")
async def post(image: UploadFile):
    """Classify an uploaded image and return an HTML fragment with the results.

    The raw upload is saved under ``uploads/`` so the fragment can display it
    through the ``/image/{filename}`` route. The bytes are then passed to
    ``predict``, and the top three labels are rendered as percentage progress
    bars. Any exception is rendered as an error fragment rather than
    propagated, since htmx swaps whatever is returned into ``#result``.
    """
    import asyncio
    import re
    import uuid
    from datetime import datetime

    try:
        # Read the uploaded image
        image_bytes = await image.read()

        # The filename is client-controlled. Keep only its final component
        # (dropping "../" or "C:\..." prefixes) and replace anything outside a
        # conservative character set (spaces included) with "_". This keeps the
        # saved file inside uploads/ and makes the name safe to embed in a URL.
        raw_name = (image.filename or "").replace("\\", "/")
        base_name = re.sub(r"[^A-Za-z0-9._-]", "_", os.path.basename(raw_name)) or "upload"
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Random suffix: two uploads of the same file in the same second must
        # not overwrite each other.
        safe_filename = f"{timestamp}_{uuid.uuid4().hex[:8]}_{base_name}"

        # Save the image so the result fragment can reference it.
        img_path = os.path.join("uploads", safe_filename)
        with open(img_path, "wb") as f:
            f.write(image_bytes)

        # Small delay to make the loading indicator visible. Awaited (not
        # time.sleep) so other requests keep being served meanwhile.
        await asyncio.sleep(0.5)

        # NOTE(review): predict() runs synchronously inside this coroutine; with
        # a real model it will block the event loop for the inference time.
        results = predict(image_bytes)
        top_results = sorted(results.items(), key=lambda kv: kv[1], reverse=True)[:3]

        # One row per label: name, percentage and a proportional progress bar.
        prediction_items = []
        for label, prob in top_results:
            percentage = int(prob * 100)
            prediction_items.append(
                Div(
                    Div(
                        Span(f"{label}"),
                        Span(f"{percentage}%"),
                        cls="label-text"
                    ),
                    Div(
                        Div(cls="progress-fill", style=f"width: {percentage}%;"),
                        cls="progress-bar"
                    ),
                    cls="prediction-item"
                )
            )

        return Div(
            H2("Prediction Results:"),
            Img(src=f"/image/{safe_filename}", cls="result-image", alt="Uploaded image"),
            Div(*prediction_items, cls="prediction-list"),
            cls="prediction-results"
        )
    except Exception as e:
        return Div(
            H2("Error"),
            P(f"An error occurred during prediction: {str(e)}"),
            cls="error-message"
        )
# Route to serve saved images
@rt("/image/{filename}")
def get(filename: str):
    """Serve a previously saved image from the ``uploads/`` directory.

    Args:
        filename: Name of the file inside ``uploads/``, taken from the URL.

    Returns:
        A ``FileResponse`` for the file, or an error fragment if no regular
        file with that name exists in ``uploads/``.
    """
    # Only the final path component is honoured. A crafted name such as ".."
    # or one containing separators therefore cannot reach outside uploads/,
    # and isfile() rejects directories.
    safe_name = os.path.basename(filename)
    file_path = os.path.join("uploads", safe_name)
    if safe_name and os.path.isfile(file_path):
        return FileResponse(file_path)
    return Div(
        H2("Error"),
        P("Image not found."),
        cls="error-message"
    )
# Route for example image
@rt("/predict_example")
def get():
    """Classify the bundled example image and return an HTML results fragment.

    Reads ``image.jpg`` from the working directory and copies it to
    ``uploads/example.jpg`` so the ``/image/{filename}`` route can display it.
    It then runs ``predict`` and renders the top three labels as progress
    bars. A missing example file and any exception are rendered as error
    fragments.
    """
    example_path = "image.jpg"
    example_name = "example.jpg"
    try:
        # Guard clause: nothing to classify without the example file.
        if not os.path.exists(example_path):
            return Div(
                H2("Example Not Found"),
                P("The example image 'image.jpg' was not found. Please try uploading your own image."),
                cls="error-message"
            )

        with open(example_path, "rb") as f:
            image_bytes = f.read()
        # Copy into uploads/ so the image route can serve it back to the page.
        with open(f"uploads/{example_name}", "wb") as f:
            f.write(image_bytes)

        # Add a small delay to make the loading indicator visible.
        # NOTE(review): blocking sleep in a sync handler; confirm the framework
        # runs sync handlers off the event loop.
        time.sleep(0.5)

        results = predict(image_bytes)
        top_results = sorted(results.items(), key=lambda kv: kv[1], reverse=True)[:3]

        # One row per label: name, percentage and a proportional progress bar.
        prediction_items = []
        for label, prob in top_results:
            percentage = int(prob * 100)
            prediction_items.append(
                Div(
                    Div(
                        Span(f"{label}"),
                        Span(f"{percentage}%"),
                        cls="label-text"
                    ),
                    Div(
                        Div(cls="progress-fill", style=f"width: {percentage}%;"),
                        cls="progress-bar"
                    ),
                    cls="prediction-item"
                )
            )

        return Div(
            H2("Prediction Results:"),
            Img(src=f"/image/{example_name}", cls="result-image", alt="Example image"),
            Div(*prediction_items, cls="prediction-list"),
            P("This is a demonstration using the provided example image.", style="font-style: italic; color: #666;"),
            cls="prediction-results"
        )
    except Exception as e:
        return Div(
            H2("Error"),
            P(f"An error occurred with the example: {str(e)}"),
            cls="error-message"
        )
# Health check endpoint (useful for Docker/Kubernetes)
@rt("/health")
def get():
    """Report liveness and whether a real model (not demo mode) is loaded.

    Returns:
        ``{"status": "ok", "model_loaded": bool}``, where ``model_loaded`` is
        False when the app fell back to demo mode.
    """
    return {"status": "ok", "model_loaded": learn is not None}
# Run the app
if __name__ == "__main__":
    # Use environment variables if available (common in Docker)
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", 8000))
    print(f"Starting FastHTML server on {host}:{port}")
    # fasthtml's serve() expects the *name* of the app variable (default "app")
    # and builds a "module:app" import string for uvicorn from it. Passing the
    # app object itself yields an invalid import string, so rely on the default.
    serve(host=host, port=port)