etweedy commited on
Commit
e4d812f
·
1 Parent(s): cf5a346

Upload 8 files

Browse files
Files changed (8) hide show
  1. app.py +40 -0
  2. cat_pot.jpeg +0 -0
  3. chair_sofa.jpeg +0 -0
  4. cow_bike.jpeg +0 -0
  5. dog_plane.jpeg +0 -0
  6. horse_sheep.jpeg +0 -0
  7. obj_class2.pkl +3 -0
  8. pizza.jpeg +0 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastai.vision.all import *
2
+ import gradio as gr
3
+
4
+ # Define custom functions for the model
5
+ def get_x(r): return path/'train'/r['fname']
6
+ def get_y(r): return r['labels'].split(' ')
7
+ def splitter(df):
8
+ train = df.index[~df['is_valid']].tolist()
9
+ valid = df.index[df['is_valid']].tolist()
10
+ return train,valid
11
+
12
+ # Load the model
13
+ learn=load_learner('obj_class2.pkl')
14
+
15
+ # The loss function has default threshold of 0.5. It seems to do better with 0.3.
16
+ learn.loss_func = BCEWithLogitsLossFlat(thresh=0.3)
17
+
18
+ # Pull out the list of categories from the model
19
+ categories = learn.dls.vocab
20
+ cat_list = [x for x in categories]
21
+
22
+ # Function for classifying image.
23
+ def classify_image(img):
24
+ pred,idx,probs = learn.predict(img)
25
+ idx = list(idx)
26
+ answer = ' and '.join([cat_list[i] for i in np.where(idx)[0].tolist()])
27
+ if answer:
28
+ return answer
29
+ else:
30
+ return "I don't recognize anything..."
31
+
32
+ # Initialize and launch gradio interface
33
+ image = gr.inputs.Image(shape=(192,192))
34
+ label = gr.outputs.Label()
35
+ title = 'Object finder'
36
+ description = "This app will try to find certain types of objects in the photo it's given. Try one of the examples, or upload your own photo! Keep in mind that it only will recognize the following objects: aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair, cow, diningtable, dog, horse, motorbike, person, pottedplant, sheep, sofa, train, or tvmonitor"
37
+ examples = ['cat_pot.jpeg','cow_bike.jpeg','dog_plane.jpeg','horse_sheep.jpeg','chair_sofa.jpeg','pizza.jpeg']
38
+
39
+ intf = gr.Interface(fn=classify_image,inputs=image,outputs=label,examples=examples, title=title,description=description)
40
+ intf.launch(inline=False)
cat_pot.jpeg ADDED
chair_sofa.jpeg ADDED
cow_bike.jpeg ADDED
dog_plane.jpeg ADDED
horse_sheep.jpeg ADDED
obj_class2.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49632ae0d32becd4083ef109c25c3dbd52005fb980299364789a90a3c47fccf0
3
+ size 102996895
pizza.jpeg ADDED