Kr1n3 committed
Commit 8415643
Parent: 260b201

Update app.py

Files changed (1):
  1. app.py +16 -21
app.py CHANGED
@@ -1,37 +1,32 @@
- import torch
- import gradio as gr
- from huggingface_hub import hf_hub_download
- from PIL import Image
- pip install -qr https://raw.githubusercontent.com/ultralytics/yolov5/master/requirements.txt #
+ import torch as tf

- REPO_ID = "Kr1n3/Fashion-Items-Classification"
- FILENAME = "best.pt"
+ import gradio as gr

- yolov5_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
+ inception_net = tf.keras.models.load_model('best.pt')  # load the model

- model = torch.hub.load('ultralytics/yolov5', 'custom', path=yolov5_weights, force_reload=True)  # local repo
+ labels = ['bom', 'ruim']

- def object_detection(im, size=640):
-     results = model(im)  # inference
-     #results.print()  # print results to screen
-     #results.show()  # display results
-     #results.save()  # save as results1.jpg, results2.jpg... etc.
-     results.render()  # updates results.imgs with boxes and labels
-     return Image.fromarray(results.imgs[0])

  title = "Fashion Items Classification"
  description = """
  """

+ def classify_image(inp):
+     inp = inp.reshape((-1, 640, 640, 3))
+     inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
+     prediction = inception_net.predict(inp).flatten()
+     return (labels[1] if float(prediction) >= 0 else labels[0])
+
+
  image = gr.inputs.Image(shape=(640, 640), image_mode="RGB", source="upload", label="Imagem", optional=False)
- outputs = gr.outputs.Image(type="pil", label="Output Image")
+ label = gr.outputs.Textbox(type="auto", label="Classificação")

  gr.Interface(
-     fn=object_detection,
+     fn=classify_image,
      inputs=image,
-     outputs=outputs,
+     outputs=label,
      title=title,
      description=description,
      examples=[["https://github.com/Kr1n3/MPC_2022/blob/main/dataset/pants_30.jpeg?raw=true"], ["https://github.com/Kr1n3/MPC_2022/blob/main/dataset/bag_01.jpg?raw=true"],
-               ["https://github.com/Kr1n3/MPC_2022/blob/main/dataset/pants_33.jpg?raw=true"], ["https://github.com/Kr1n3/MPC_2022/blob/main/dataset/bag_14.JPG?raw=true"]],
- ).launch()
+               ["https://github.com/Kr1n3/MPC_2022/blob/main/dataset/bag_14.JPG?raw=true"], ["https://github.com/Kr1n3/MPC_2022/blob/main/dataset/dress_45.JPG?raw=true"]],
+ ).launch()
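
For reference, below is the updated app.py read as a single file. It is a minimal sketch rather than the committed code verbatim, and it relies on two assumptions the diff does not confirm: the `tf` alias is taken to mean TensorFlow (the commit imports torch as tf but only ever calls tf.keras APIs), and 'best.pt' is assumed to be saved in a format that tf.keras.models.load_model can read despite its PyTorch-style extension. It also indexes the first element of the flattened prediction instead of calling float() on the whole array. The gr.inputs / gr.outputs names are kept exactly as in the commit; they belong to the older Gradio API this Space was written against and are deprecated in newer releases.

import tensorflow as tf  # assumption: TensorFlow, not `import torch as tf` as committed
import gradio as gr

# Assumption: 'best.pt' is a Keras-loadable model file despite the .pt extension.
inception_net = tf.keras.models.load_model('best.pt')

labels = ['bom', 'ruim']  # Portuguese for 'good', 'bad'

title = "Fashion Items Classification"
description = """
"""

def classify_image(inp):
    # Treat the upload as a batch of one 640x640 RGB image and apply MobileNetV2 preprocessing.
    inp = inp.reshape((-1, 640, 640, 3))
    inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
    prediction = inception_net.predict(inp).flatten()
    # Same single-score threshold as the committed code; indexing avoids float() on an array.
    return labels[1] if float(prediction[0]) >= 0 else labels[0]

image = gr.inputs.Image(shape=(640, 640), image_mode="RGB", source="upload", label="Imagem", optional=False)
label = gr.outputs.Textbox(type="auto", label="Classificação")

gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=label,
    title=title,
    description=description,
    examples=[["https://github.com/Kr1n3/MPC_2022/blob/main/dataset/pants_30.jpeg?raw=true"],
              ["https://github.com/Kr1n3/MPC_2022/blob/main/dataset/bag_01.jpg?raw=true"],
              ["https://github.com/Kr1n3/MPC_2022/blob/main/dataset/bag_14.JPG?raw=true"],
              ["https://github.com/Kr1n3/MPC_2022/blob/main/dataset/dress_45.JPG?raw=true"]],
).launch()

If 'best.pt' is in fact the YOLOv5 checkpoint from the Kr1n3/Fashion-Items-Classification repo, the removed torch.hub.load('ultralytics/yolov5', 'custom', path=...) route is the loader that matches that file.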