File size: 3,288 Bytes
988a945
 
 
 
 
 
304ad99
2ea737d
988a945
 
 
 
 
307f250
41f68b4
4c1a80d
41f68b4
 
 
19a7d81
304ad99
 
7988f17
19a7d81
 
 
 
0d3a1e8
76818d8
7147635
efc152c
9915153
2ea737d
9915153
 
ed609ce
 
 
9915153
ed609ce
 
 
9915153
ed609ce
 
 
9915153
ed609ce
 
 
9915153
ed609ce
 
 
9915153
ed609ce
 
 
9915153
ed609ce
 
64fa63b
ed609ce
2ea737d
7988f17
efc152c
7988f17
 
1e0a18d
19a7d81
 
 
 
 
 
 
 
 
9fa2277
ad1c17e
 
 
1b61e41
 
ad1c17e
ae0f8d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19a7d81
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr 
from PIL import Image
import os
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision
from ultralytics.utils.plotting import Annotator, colors 
import numpy as np
import yaml
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

# Use a CUDA device when torch reports one as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the detector weights bundled with the app and move them to `device`.
model = YOLO('Models/best.pt').to(device)

def load_img(filename):
    """Return an RGB ``PIL.Image`` for *filename*.

    Args:
        filename: A local file path, an ``http://``/``https://`` URL, or an
            object that is already an image. Non-string values are returned
            unchanged.

    Returns:
        The loaded image converted to RGB for string inputs, otherwise the
        argument itself.
    """
    # The original body referenced an unbound local `img`; alias the
    # parameter so the function works for every input.
    img = filename
    if not isinstance(img, str):
        return img

    if img.startswith(('http://', 'https://')):
        # Fetch remote images in-process; the previously referenced
        # `get_url_img` helper does not exist anywhere in this module.
        import io
        import urllib.request

        with urllib.request.urlopen(img, timeout=30) as response:
            payload = response.read()
        return Image.open(io.BytesIO(payload)).convert('RGB')

    return Image.open(img).convert('RGB')

# (label, colour) drawn for each class id this app knows about. Ids are assumed
# to match the class indices of Models/best.pt — confirm against its `names`.
_CLASS_STYLES = {
    0: ('bicycle', (255, 0, 0)),
    1: ('bus', (0, 0, 255)),
    2: ('car', (0, 255, 0)),
    3: ('motorcycle', (255, 0, 255)),
    4: ('person', (255, 128, 0)),
    5: ('train', (255, 0, 128)),
    6: ('truck', (0, 255, 255)),
}


def _style_for(cls_id, names):
    """Return the (label, colour) to draw for class id *cls_id*.

    Ids in ``_CLASS_STYLES`` use the fixed table; any other id falls back to
    the model's own class name (or the id itself) and ultralytics' palette,
    instead of crashing or reusing the previous box's style.
    """
    if cls_id in _CLASS_STYLES:
        return _CLASS_STYLES[cls_id]
    return str(names.get(cls_id, cls_id)), colors(cls_id)


def process_img(image):
    """Run the YOLO detector on *image* and draw the detected boxes.

    Args:
        image: PIL image from the Gradio input component, or ``None`` when
            the user submits without providing an image.

    Returns:
        ``Annotator.result()`` with every detected box labelled when at
        least one box was found; the input image unchanged when nothing was
        detected; ``None`` when no image was given.
    """
    if image is None:
        # Nothing to analyse: don't hand source=None to ultralytics (recent
        # versions substitute their bundled sample images) or im=None to
        # Annotator.
        return None

    with torch.no_grad():
        result = model(source=image)[0]

    boxes = result.boxes
    if len(boxes) == 0:
        return image

    ann = Annotator(im=image)
    names = result.names or {}
    for element in boxes:
        box = np.array(element.xyxy.cpu()).flatten()
        # Read the class id once per box instead of once per comparison.
        cls_id = int(element.cls[0].item())
        label, clr = _style_for(cls_id, names)
        ann.box_label(box=box, label=label, color=clr)
    return ann.result()
    
# Heading shown at the top of the Gradio page.
title = "Efficient Hazy Vehicle Detection ✏️🚗🤗"

# Markdown rendered under the title: project link, author, affiliation and
# usage caveats for the demo.
description = ''' ## [Efficient Hazy Vehicle Detection](https://github.com/cidautai)
[Paula Garrido Mellado](https://github.com/paugar5)
Fundación Cidaut
> **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
**This demo expects an image with some degradations.**
Due to the GPU memory limitations, the app might crash if you feed a high-resolution image (2K, 4K).
<br>
'''

# Sample inputs offered by the interface: one single-image row per file
# under examples/.
examples = [
    [f"examples/{scene}.jpg"]
    for scene in (
        "dusttornado",
        "foggy",
        "haze",
        "mist",
        "rain_storm",
        "sand_storm",
        "snow_storm",
    )
]

# Extra CSS for the page: image elements keep automatic width/height and
# have no max-width cap.
css = """
    .image-frame img, .image-container img {
        width: auto;
        height: auto;
        max-width: none;
    }
"""

# Input and output components: a PIL image in, the annotated PIL image out.
_input_image = gr.Image(type='pil', label='input')
_output_image = gr.Image(type='pil', label='output')

# Wire the detector into a single-function Gradio interface.
demo = gr.Interface(
    fn=process_img,
    inputs=[_input_image],
    outputs=[_output_image],
    title=title,
    description=description,
    examples=examples,
    css=css,
)

if __name__ == '__main__':
    demo.launch()