File size: 2,619 Bytes
2d998c0
 
 
894991e
2d998c0
451c856
 
 
1722c72
451c856
 
 
894991e
 
 
efdf5ed
894991e
 
 
2d998c0
 
894991e
 
 
2d998c0
 
 
 
 
 
894991e
 
 
 
 
 
 
 
 
 
 
 
 
ed85e7f
 
 
 
 
 
 
 
894991e
 
 
2d998c0
894991e
 
 
2d998c0
894991e
 
 
 
 
 
 
 
 
2d998c0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import pathlib
import platform

from fasthtml.common import *
from fasthtml.components import *
from fastai.vision.all import *

# Page-wide CSS, handed to `fast_app` below as an extra header element.
# NOTE: no trailing comma after the call -- `style` must be the Style element
# itself, not an accidental one-element tuple.
style = Style("""
    body { font-family: Arial, sans-serif; line-height: 1.6; color: #333; max-width: 800px; margin: 0 auto; padding: 20px; }
    h2 { color: #2c3e50; }
    img { height: 100px; width: auto; }
    .image-box { display: flex; }
""")

# Ensure the uploads directory exists
os.makedirs("uploads", exist_ok=True)

# `(style,)` is an explicit one-element tuple; `(style)` would just be `style`.
app, rt = fast_app(hdrs=(style,))

# On Windows, alias PosixPath to WindowsPath before loading the model.
# NOTE(review): presumably the exported learner contains pickled PosixPath
# objects (exported on Linux/macOS), which cannot be instantiated on Windows
# -- confirm against how model.pkl was produced.
if platform.system() == 'Windows':
    pathlib.PosixPath = pathlib.WindowsPath

# NOTE(review): `load_learner` presumably unpickles this file; only ship a
# model.pkl that comes from a trusted source.
learn = load_learner("model.pkl")

def classify_image(image_path):
    """Classify the image stored at *image_path* with the module-level learner.

    Returns a ``(name, idx, probs)`` tuple. ``name`` is the first value from
    ``learn.predict`` with underscores replaced by spaces and every word
    capitalised; ``idx`` and ``probs`` are the second and third values from
    ``learn.predict``, passed through unchanged.
    """
    label, idx, probs = learn.predict(load_image(image_path))
    # "some_label" -> "Some Label"
    words = label.split("_")
    name = " ".join(word.capitalize() for word in words)
    return name, idx, probs

@rt('/')
def index():
    """Build the landing page: upload form, result placeholder and sample images."""
    # Upload form; HTMX posts it to /classify and swaps the reply into #result.
    upload_form = Form(
        Input(type="file", name="img_file", accept="image/*", required=True),
        Button("Classify"),
        enctype="multipart/form-data",
        hx_post="/classify",
        hx_target="#result",
    )

    # Clicking a sample image issues a GET to /classify carrying that image's
    # file name as the `image_path` value.
    sample_files = ("chair1.jpg", "chair2.jpg", "lamp1.jpg", "lamp2.jpg")
    sample_images = [
        Img(src=fname, hx_trigger="click", hx_get="/classify",
            hx_target="#result", hx_vals=f'{{"image_path": "{fname}"}}')
        for fname in sample_files
    ]

    page_body = Main(
        H2("Upload an Image"),
        upload_form,
        Br(),
        Div(id="result"),
        H2("Test Images"),
        Div(*sample_images, cls="image-box"),
    )
    return Titled("Chair vs Lamp Classifier"), page_body
    
def _basename_only(raw):
    """Return the final path component of *raw*, or "" if there is none.

    Both "/" and "\\" are treated as separators regardless of the host OS, and
    the special components "." and ".." are rejected, so the result can never
    point outside the directory it is later joined to.
    """
    candidate = pathlib.PurePosixPath((raw or "").replace("\\", "/")).name
    return "" if candidate in (".", "..") else candidate


@rt('/classify', methods=['GET', 'POST'])
async def classify(img_file: UploadFile|None = None,
                   image_path:Str|None = None):
    """Classify an uploaded file or one of the sample images.

    Args:
        img_file: Uploaded image (form POST). When given, it is saved under
            ``uploads/`` and takes precedence over ``image_path``.
        image_path: File name of an image to classify (sent by the sample
            images on the index page).

    Returns:
        ``(Img, Div)`` showing the image and the prediction, or a single
        ``Div`` with a message when no usable image was supplied.
    """
    if img_file is not None:
        # The file name is client-controlled: keep only its last component so
        # something like "../../app.py" cannot be written outside uploads/.
        safe_name = _basename_only(img_file.filename)
        if not safe_name:
            return Div(P("The uploaded file has no usable file name."))
        # Save the uploaded image
        image_path = f"uploads/{safe_name}"
        with open(image_path, "wb") as f:
            f.write(await img_file.read())
    else:
        # `image_path` is client-controlled as well; restrict it to a bare file
        # name so requests cannot make the server open arbitrary paths.
        image_path = _basename_only(image_path)

    if not image_path:
        # Neither an upload nor a sample image was supplied.
        return Div(P("Please upload an image or pick one of the test images."))

    name, idx, probs = classify_image(image_path)
    return Img(src=image_path), Div(P(f"This is a {name}."),
               P(f"Probability that it's a {name}: {probs[idx]:.4f}"))

# Start the app. NOTE(review): `serve` is not defined in this file; presumably
# it is FastHTML's server helper pulled in by the star imports -- confirm.
serve()