File size: 2,639 Bytes
5e2f8af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import base64
import gradio as gr

from fastcore.all import *
from fastai.vision.all import *
import numpy as np
import timm


def parent_labels(o):
    """Return the labels encoded in the name of `o`'s parent folder.

    The parent directory name is split on commas, so a file stored under
    ``cat,dog/`` is labelled ``["cat", "dog"]``. A folder name without a
    comma yields a single-element list.
    """
    # NOTE(review): presumably referenced by the pickled learner, so the
    # function name must stay importable from this module -- confirm.
    folder_name = Path(o).parent.name
    labels = folder_name.split(",")
    return labels

class LabelSmoothingBCEWithLogitsLossFlat(BCEWithLogitsLossFlat):
    """`BCEWithLogitsLossFlat` variant that smooths the targets before the loss.

    Each target value ``t`` is replaced by ``t * (1 - eps) + 0.5 * eps``, so
    hard 0/1 labels become ``eps / 2`` and ``1 - eps / 2`` respectively.
    """

    def __init__(self, eps:float=0.1, **kwargs):
        """Store the smoothing factor `eps` and build the base loss.

        The base class is always constructed with ``thresh=0.1``; every other
        keyword argument is forwarded unchanged.
        """
        # `eps` is assigned before the base initialiser runs, as in the
        # original implementation.
        self.eps = eps
        super().__init__(thresh=0.1, **kwargs)

    def __call__(self, inp, targ, **kwargs):
        """Smooth `targ` and delegate to the base loss with the smoothed targets."""
        smoothing = self.eps
        smoothed_targ = targ.float() * (1. - smoothing) + 0.5 * smoothing
        return super().__call__(inp, smoothed_targ, **kwargs)


# Load the exported learner from 'models.pkl' once, at import time, so every
# request reuses the same model.
# NOTE(review): fastai learners are presumably pickled -- only load files from a
# trusted source, and keep `parent_labels` / `LabelSmoothingBCEWithLogitsLossFlat`
# defined in this module in case the pickle refers to them. Confirm.
learn = load_learner('models.pkl')
# Set a new loss function with a threshold of 0.4 (the label-smoothing loss
# defined above is built with thresh=0.1) to remove more false positives.
learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)


def predict(image, vtt):
    """Tag every frame of a sprite sheet and report the best frame per tag.

    Args:
        image: Sprite sheet holding all preview frames; handed to
            `PILImage.create`.
        vtt: Base64-encoded WebVTT text, optionally carrying the
            ``data:text/vtt;base64,`` data-URL prefix. Its ``xywh=`` cues
            locate each frame inside the sprite sheet.

    Returns:
        A dict mapping each predicted tag (underscores replaced by spaces) to
        ``{'prob': float, 'offset': (left, top, width, height)}``: the highest
        probability seen for that tag among the frames where it was predicted,
        and the sprite offset of that frame. Empty when the VTT contains no
        ``xywh=`` cue.
    """
    vtt = base64.b64decode(vtt.replace("data:text/vtt;base64,", ""))
    sprite = PILImage.create(image)
    offsets = []
    images = []
    for left, top, width, height in getVTToffsets(vtt):
        offsets.append((left, top, width, height))
        # The cue gives x, y, width, height, while crop() takes the box as
        # (left, upper, right, lower), hence the additions.
        cut_frame = sprite.crop((left, top, left + width, top + height))
        images.append(PILImage.create(np.asarray(cut_frame)))

    # Nothing to classify: return early instead of running inference on an
    # empty batch.
    if not images:
        return {}

    # create dataset
    test_dl = learn.dls.test_dl(images, bs=64)
    # get predictions; the middle element of the returned triple is unused.
    # NOTE(review): with_decoded=True presumably makes the third element the
    # per-tag decision under the loss function's threshold -- confirm
    # against the fastai version in use.
    probabilities, _, activations = learn.get_preds(dl=test_dl, with_decoded=True)

    # swivel into tags list from activations, keeping the most confident
    # frame for every tag
    tags = {}
    vocab = learn.dls.vocab
    for frame_idx, decoded in enumerate(activations):
        for tag_idx, is_present in enumerate(decoded):
            if not is_present:
                continue
            tag = vocab[tag_idx].replace("_", " ")
            prob = float(probabilities[frame_idx][tag_idx])
            best = tags.setdefault(tag, {'prob': 0, 'offset': ()})
            if best['prob'] < prob:
                best['prob'] = prob
                best['offset'] = offsets[frame_idx]

    return tags


def getVTToffsets(vtt):
    """Yield the sprite rectangle of every ``xywh=`` cue in a WebVTT file.

    Args:
        vtt: Raw WebVTT content as UTF-8 encoded bytes.

    Yields:
        ``(left, top, right, bottom)`` tuples of ints, in file order. The
        values are taken verbatim from ``xywh=x,y,w,h``, so the last two are
        the frame's width and height rather than absolute coordinates.

    Raises:
        ValueError: If a cue does not carry exactly four integer values.
    """
    for line in vtt.decode("utf-8").split("\n"):
        line = line.strip()
        if "xywh=" not in line:
            continue

        # Every parsed cue is yielded, including those with an x offset of 0:
        # a truthiness check on `left` would drop the whole first column of
        # the sprite sheet.
        left, top, right, bottom = (
            int(value) for value in line.split("xywh=")[-1].split(",")
        )
        yield left, top, right, bottom


# Build and start the web UI: `predict` receives the uploaded image and the
# text entered in the "VTT file" box, and its returned dict is rendered as JSON.
# server_name="0.0.0.0" makes the server listen on all network interfaces.
# NOTE(review): `enable_queue` is reportedly deprecated/removed in newer Gradio
# releases -- confirm against the pinned Gradio version.
gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(),
        gr.Textbox(label="VTT file"),
    ],
    outputs=gr.JSON(label=""),
).launch(enable_queue=True, server_name="0.0.0.0")