# Spaces: Sleeping  (Hugging Face Spaces page header captured along with this source; not part of the program)
import functools
import os
import threading
import traceback
import urllib.parse
from io import BytesIO

from fastai.vision.all import *
from fasthtml.common import *
from PIL import Image
from starlette.responses import JSONResponse
from fasthtml_hf import setup_hf_backup
# Page-wide <head> content: Tailwind (Play CDN with the forms/typography
# plugins) plus a few custom utility classes used by the layout.
_PAGE_HEADERS = (
    Script(src="https://cdn.tailwindcss.com?plugins=forms,typography"),
    Style("""
body { background-color: #f9fafb; }
.card-shadow { box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05); }
.gradient-bg { background: linear-gradient(135deg, #3b82f6 0%, #1e40af 100%); }
"""),
)

# Pico CSS is disabled because styling is done with Tailwind utilities.
app, rt = fast_app(pico=False, hdrs=_PAGE_HEADERS)
# Client-side drag-and-drop support for the #drop-zone element.
# A dropped file is POSTed to /upload. Only after the server confirms the save
# are the preview <img> and the hidden #selected-filename field pointed at the
# saved copy, so a failed upload never leaves a stale image selected.
drag_drop_js = """
function setupDragAndDrop() {
    const dropZone = document.getElementById('drop-zone');
    const previewImg = document.getElementById('preview-img');
    const filenameInput = document.getElementById('selected-filename');
    if (!dropZone || !previewImg || !filenameInput) return;

    // Keep the browser from navigating to / opening the dropped file itself.
    ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName =>
        dropZone.addEventListener(eventName, evt => { evt.preventDefault(); evt.stopPropagation(); }, false));

    // Highlight while a file hovers. Use classes that are NOT part of the
    // zone's base styling, so removing them restores the original look.
    ['dragenter', 'dragover'].forEach(eventName =>
        dropZone.addEventListener(eventName, () => dropZone.classList.add('border-blue-500', 'bg-blue-100'), false));
    ['dragleave', 'drop'].forEach(eventName =>
        dropZone.addEventListener(eventName, () => dropZone.classList.remove('border-blue-500', 'bg-blue-100'), false));

    dropZone.addEventListener('drop', evt => {
        if (!evt.dataTransfer.files.length) return;
        const formData = new FormData();
        formData.append('file', evt.dataTransfer.files[0]);
        fetch('/upload', { method: 'POST', body: formData })
            .then(response => response.json())
            .then(data => {
                if (!data.success) throw new Error(data.error || 'Upload failed');
                const filename = data.filename || 'drop.jpg';
                // Cache-buster so the browser refetches the overwritten file.
                previewImg.src = '/' + filename + '?' + new Date().getTime();
                filenameInput.value = filename;
            })
            .catch(error => console.error('Error:', error));
    }, false);
}
document.addEventListener('DOMContentLoaded', setupDragAndDrop);
"""
# Enhanced image item component | |
def image_item(filename, color, label): | |
return Div( | |
Div( | |
Img(src=f"/{filename}", | |
cls=f"w-full h-auto cursor-pointer rounded-lg border-[5px] border-{color}", | |
onclick=f"document.getElementById('preview-img').src='/{filename}'; document.getElementById('selected-filename').value='{filename}';"), | |
cls="overflow-hidden" | |
), | |
P(f"{label} Labrador", cls="text-sm font-medium text-gray-700 mt-2 text-center"), | |
cls="w-[180px]" # Increased width from 130px to 180px | |
) | |
@rt("/")
def get():
    """Render the main page.

    The page has a header, an image-analysis panel (drop zone, preview, and the
    "Analyze Image" form), sample-image tiles, a results panel, and a footer.
    The form posts the hidden ``filename`` field to ``/loading`` and swaps the
    response into ``#predictions``. ``drag_drop_js`` is appended at the end of
    the body to wire up the drop zone by element id.
    """
    return Div(
        # Header: gradient bar with logo/title on the left and a social link on the right
        Div(
            Div(
                Div(
                    # Left side: logo next to the title text
                    Div(
                        Img(src="/lab-logo.png", alt="Labrador Classifier Logo",
                            cls="h-20 w-auto mr-4"),
                        Div(
                            H1("Labrador Classifier", cls="text-2xl font-bold text-white m-0"),
                            P("Identify the type of Labrador using AI", cls="text-blue-100 m-0"),
                        ),
                        cls="flex items-center"
                    ),
                    # Right side: link to the author's X/Twitter profile
                    A(
                        Img(src="/logo-white.png", alt="Twitter",
                            cls="h-6 w-auto transition-transform hover:scale-110"),
                        href="https://x.com/dgwyer",
                        title="Follow me for more AI content!",
                        target="_blank",
                        rel="noopener noreferrer",
                        cls="flex items-center"
                    ),
                    # justify-between pushes the two sides to opposite ends
                    cls="flex justify-between items-center"
                ),
                cls="max-w-6xl mx-auto px-4 py-6"
            ),
            cls="gradient-bg w-full mb-8"
        ),
        # Main content: two columns on large screens, stacked on small ones
        Div(
            # Left column
            Div(
                # Selected-image panel
                Div(
                    H2("Image Analysis", cls="text-xl font-semibold text-gray-800 mb-4 pb-2 border-b"),
                    Div(
                        # Drop zone + preview; ids are looked up by drag_drop_js and image_item
                        Div(
                            Img(id="preview-img", src="/black.jpg",
                                cls="w-full h-auto object-contain rounded-lg mb-4 mx-auto block min-h-[200px] max-h-[200px]"),
                            P("Drag & Drop Image Here",
                              cls="text-gray-500 text-sm absolute bottom-4 left-0 right-0 text-center bg-white bg-opacity-75 py-2"),
                            id="drop-zone",
                            cls="relative w-full border-2 border-dashed border-blue-300 p-8 rounded-xl text-center cursor-pointer transition-colors bg-blue-50 bg-opacity-50 mb-4 hover:bg-blue-100 hover:border-blue-400"
                        ),
                        # Prediction form: the hidden field holds the currently selected image
                        Form(
                            Input(type="hidden", id="selected-filename", name="filename", value="black.jpg"),
                            Button('Analyze Image', type="submit",
                                   cls="w-full py-3 px-4 rounded-lg bg-blue-600 text-white font-medium shadow-md hover:bg-blue-700 transition-colors focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50"),
                            hx_post="/loading", hx_target="#predictions", cls="mt-3"
                        ),
                        cls="w-full max-w-lg mx-auto"
                    ),
                    cls="bg-white rounded-xl p-6 mb-6 card-shadow"
                ),
                # Sample-image panel
                Div(
                    H2("Sample Images", cls="text-xl font-semibold text-gray-800 mb-4 pb-2 border-b"),
                    P("Click an image to analyze", cls="text-gray-600 mb-4"),
                    Div(
                        image_item("black.jpg", "gray-800", "Black"),
                        image_item("yellow.jpg", "yellow-500", "Yellow"),
                        image_item("chocolate.jpg", "amber-700", "Chocolate"),
                        cls="flex flex-row justify-center gap-8 mx-auto"
                    ),
                    cls="bg-white rounded-xl p-6 card-shadow"
                ),
                cls="w-full lg:w-2/3 pr-0 lg:pr-8"
            ),
            # Right column: results panel; #predictions is the htmx swap target
            Div(
                Div(
                    H2('Results', cls="text-xl font-semibold text-gray-800 mb-4 pb-2 border-b"),
                    Div(
                        Div(
                            NotStr('<svg class="w-12 h-12 text-blue-500 mx-auto mb-4" fill="none" stroke="currentColor" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9.663 17h4.673M12 3v1m6.364 1.636l-.707.707M21 12h-1M4 12H3m3.343-5.657l-.707-.707m2.828 9.9a5 5 0 117.072 0l-.548.547A3.374 3.374 0 0014 18.469V19a2 2 0 11-4 0v-.531c0-.895-.356-1.754-.988-2.386l-.548-.547z"></path></svg>'),
                            H3("Ready for Analysis", cls="text-lg font-medium text-gray-700 text-center"),
                            P('Click the "Analyze Image" button to identify the type of Labrador.',
                              cls="text-gray-600 text-center"),
                            cls="py-8"
                        ),
                        id="predictions",
                        cls="bg-gray-50 rounded-lg p-4 min-h-[250px] flex items-center"
                    ),
                    cls="bg-white rounded-xl p-6 card-shadow sticky top-6"
                ),
                cls="w-full lg:w-1/3 mt-6 lg:mt-0"
            ),
            cls="flex flex-col lg:flex-row gap-6 max-w-6xl mx-auto px-4"
        ),
        # Footer
        Div(
            P("© 2025 Labrador Classifier • By David Gwyer • Powered by FastAI and FastHTML",
              cls="text-center text-gray-500 text-sm"),
            cls="mt-12 py-6 border-t max-w-6xl mx-auto px-4"
        ),
        # Drag-and-drop wiring; placed last so the elements it looks up exist
        Script(drag_drop_js),
        cls="min-h-screen"
    )
@rt("/upload")
async def post(file: UploadFile):
    """Store a dropped image as ``drop.jpg`` so it can be previewed and analyzed.

    The upload is decoded and converted to RGB. It is then resized to 128x128
    (the aspect ratio is not preserved) and written to ``drop.jpg`` in the
    working directory.

    Args:
        file: the multipart upload sent by the drag-and-drop script.

    Returns:
        JSONResponse ``{"success": True, "filename": "drop.jpg"}`` on success.
        If the upload cannot be decoded as an image, returns
        ``{"success": False, "error": ...}`` with HTTP status 400.
    """
    data = await file.read()
    try:
        # Image.open is lazy; convert() forces the decode so bad input fails here.
        # RGB conversion also lets RGBA/palette images (e.g. PNGs) be saved as
        # JPEG, which otherwise raises OSError.
        img = Image.open(BytesIO(data)).convert("RGB").resize((128, 128), Image.LANCZOS)
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        return JSONResponse({"success": False, "error": f"Invalid image: {err}"}, status_code=400)
    # NOTE(review): all clients share this one file, so concurrent uploads overwrite each other.
    img.save("drop.jpg")
    return JSONResponse({"success": True, "filename": "drop.jpg"})
@rt("/loading")
def post(filename: str = "black.jpg"):
    """Return a spinner placeholder that immediately requests the real prediction.

    The form's ``hx_post`` swaps this element into ``#predictions``. On load,
    ``hx_trigger="load"`` makes it GET ``/process`` for ``filename``, and
    ``hx_swap="outerHTML"`` makes it replace itself with the result.

    Args:
        filename: image file name taken from the hidden form field (client-controlled).
    """
    # The value comes from the client, so percent-encode it. Characters such as
    # '&', '#' or spaces would otherwise corrupt the query string.
    query = urllib.parse.urlencode({"filename": filename})
    return Div(
        Div(
            Div(
                NotStr('<svg class="w-12 h-12 animate-spin text-blue-600 mx-auto mb-4" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24"><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>'),
                H3("Analyzing Image", cls="text-lg font-medium text-gray-700 text-center"),
                P("Please wait while we process your image...", cls="text-gray-600 text-center"),
                cls="py-8"
            ),
            cls="flex items-center justify-center"
        ),
        cls="bg-gray-50 rounded-lg p-4 min-h-[250px]",
        hx_get=f"/process?{query}",
        hx_trigger="load",
        hx_swap="outerHTML"
    )
# Illustration shown on the results card for each predicted label.
_PRED_IMAGES = {
    "black": "bl-pred.png",
    "yellow": "yl-pred.png",
    "chocolate": "cl-pred.png",
}

# The learner below is cached and shared by all requests. Sync handlers may run
# concurrently, and fastai's Learner is presumably not safe for concurrent
# predict() calls (it keeps per-batch state on the instance), so serialize them.
_predict_lock = threading.Lock()


@functools.lru_cache(maxsize=1)
def _get_learner():
    """Load ``export.pkl`` on first use and reuse it for later requests."""
    return load_learner('export.pkl', cpu=True)


@rt("/process")
def get(filename: str = "black.jpg"):
    """Classify an image from the working directory and render the results card.

    Args:
        filename: image file name (client-controlled). Directory components are
            stripped, so only files in the working directory can be analyzed.

    Returns:
        A card with the predicted label, confidence and a matching illustration.
        If the file does not exist, returns an "Image not found" card instead.
    """
    # Strip any path components sent by the client to prevent path traversal.
    safe_name = os.path.basename(filename)
    if not os.path.isfile(safe_name):
        return Div(
            H3("Image not found", cls="text-lg font-medium text-red-600 text-center"),
            P(f"No image named '{safe_name}' is available to analyze.",
              cls="text-gray-600 text-center"),
            cls="bg-gray-50 rounded-lg p-6 min-h-[250px]"
        )

    with _predict_lock:
        label, class_idx, probabilities = _get_learner().predict(safe_name)
    label_text = str(label)
    confidence = probabilities[class_idx.item()].item() * 100
    pred_image = _PRED_IMAGES.get(label_text, "")

    children = [
        # Success icon
        NotStr('<svg class="w-12 h-12 text-green-500 mx-auto mb-2" fill="none" stroke="currentColor" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z"></path></svg>'),
        H3('Analysis Complete', cls="text-lg font-medium text-gray-700 text-center mb-2"),
    ]
    # Only show an illustration for known labels (an empty name would render src="/").
    if pred_image:
        children.append(Div(
            Img(src=f"/{pred_image}", alt=f"{label_text.capitalize()} Labrador",
                cls="w-32 h-32 mx-auto mb-4 object-contain"),
            cls="text-center"
        ))
    # Prediction card: label and confidence, plus a bar scaled to the confidence
    children.append(Div(
        Div(
            Div(
                P("Prediction", cls="text-xs font-medium text-gray-500 uppercase tracking-wide"),
                P(f'{label_text.capitalize()} Labrador',
                  cls="text-lg font-bold text-blue-600"),
                cls="flex-grow"
            ),
            Div(
                P("Confidence", cls="text-xs font-medium text-gray-500 uppercase tracking-wide"),
                P(f'{confidence:.1f}%',
                  cls="text-lg font-bold text-gray-800"),
                cls="text-right"
            ),
            cls="flex justify-between items-center"
        ),
        Div(
            Div(
                cls="h-2 bg-blue-600 rounded-full",
                style=f"width: {confidence}%"
            ),
            cls="w-full bg-gray-200 rounded-full h-2 mt-2"
        ),
        cls="bg-white rounded-lg p-4 shadow-sm border border-gray-200"
    ))
    return Div(
        Div(*children),
        cls="bg-gray-50 rounded-lg p-6 min-h-[250px]"
    )
# NOTE(review): no route decorator is visible for this handler, so FastHTML never
# registers it, and the next module-level `def get` rebinds the name. Add e.g.
# `@rt("/<path>")` above it if this diagnostic page should be reachable. It also
# returns full server tracebacks, so keep it off public deployments.
def get():
    """Diagnostic view: report whether ``export.pkl`` can be loaded.

    Returns:
        A success message, or the error text plus the formatted traceback if
        loading fails.
    """
    try:
        load_learner('export.pkl', cpu=True)
    except Exception as e:  # diagnostic boundary: show any load failure
        return Div(
            H2("❌ Model Load Failed", cls="text-red-600"),
            P(f"Error: {str(e)}", cls="text-gray-700"),
            Pre(traceback.format_exc(), cls="p-4 bg-gray-100 rounded text-sm overflow-auto max-h-[300px]")
        )
    return P("✅ Model loaded successfully", cls="text-green-600 text-lg")
# NOTE(review): no route decorator is visible for this handler, so FastHTML never
# registers it. Add e.g. `@rt("/<path>")` above it if this diagnostic page should
# be reachable. It also returns full server tracebacks, so keep it off public
# deployments.
def get():
    """Diagnostic view: run the model on the bundled ``black.jpg`` sample.

    Returns:
        The predicted label and confidence (2 decimal places), or the error
        text plus the formatted traceback if loading or prediction fails.
    """
    try:
        learn = load_learner('export.pkl', cpu=True)
        label, idx, probs = learn.predict('black.jpg')
    except Exception as e:  # diagnostic boundary: show any failure
        return Div(
            H2("❌ Prediction Failed", cls="text-red-600"),
            P(f"Error: {str(e)}"),
            Pre(traceback.format_exc(), cls="p-4 bg-gray-100 rounded text-sm overflow-auto max-h-[300px]")
        )
    confidence = probs[idx.item()].item() * 100
    return Div(
        H2("✅ Prediction Successful", cls="text-green-600"),
        P(f"Label: {label}"),
        P(f"Confidence: {confidence:.2f}%")
    )
# Start-up: first hand the app to fasthtml_hf's Hugging Face Spaces backup
# helper, then launch the FastHTML server.
setup_hf_backup(app)
serve()