shewster commited on
Commit
28212a0
1 Parent(s): 661194b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -21
app.py CHANGED
@@ -1,29 +1,25 @@
1
  import gradio as gr
2
  from fastai.vision.all import *
3
- import PIL.Image
4
 
5
- # Load the learner
6
- learn = load_learner('export.pkl')
 
 
 
7
 
8
- # Get the labels from the learner's vocabulary
9
- labels = learn.dls.vocab
10
 
11
- # Define the prediction function
12
- def predict(img):
13
- img = PILImage.create(img)
14
  pred, pred_idx, probs = learn.predict(img)
15
- return {labels[i]: float(probs[i]) for i in range(len(labels))}
16
 
17
- # Define the title, description, and article for the interface
18
- title = "Cashew Classifier"
19
- description = "Classify images of cashews."
20
- article = "<p style='text-align: center'><a href='https://tmabraham.github.io/blog/gradio_hf_spaces_tutorial' target='_blank'>Blog post</a></p>"
 
 
21
 
22
- # Create and launch the Gradio interface
23
- gr.Interface(fn=predict,
24
- inputs=gr.Image(),
25
- outputs=gr.Label(num_top_classes=3),
26
- title=title,
27
- description=description,
28
- article=article,
29
- enable_queue=True).launch()
 
1
  import gradio as gr
2
  from fastai.vision.all import *
 
3
 
4
+ # Load your trained model
5
+ def load_model():
6
+ path = 'export.pkl'
7
+ learn = load_learner(path)
8
+ return learn
9
 
10
+ learn = load_model()
 
11
 
12
+ # Define prediction function
13
+ def predict_image(img):
 
14
  pred, pred_idx, probs = learn.predict(img)
15
+ return {learn.dls.vocab[i]: float(probs[i]) for i in range(len(learn.dls.vocab))}
16
 
17
+ # Create a Gradio interface
18
+ interface = gr.Interface(fn=predict_image,
19
+ inputs=gr.inputs.Image(type="pil"),
20
+ outputs=gr.outputs.Label(num_top_classes=3),
21
+ title="Image Classifier",
22
+ description="Upload an image to classify.")
23
 
24
+ if __name__ == "__main__":
25
+ interface.launch()