Spaces:
Running
Running
File size: 2,639 Bytes
5e2f8af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import base64
import gradio as gr
from fastcore.all import *
from fastai.vision.all import *
import numpy as np
import timm
def parent_labels(o):
    """Return the labels encoded in the name of the folder that contains `o`.

    The name of `o`'s immediate parent directory is split on "," so that a
    single folder name can carry several labels.
    """
    containing_dir = Path(o).parent
    folder_name = containing_dir.name
    return folder_name.split(",")
class LabelSmoothingBCEWithLogitsLossFlat(BCEWithLogitsLossFlat):
    """`BCEWithLogitsLossFlat` that label-smooths the targets before computing the loss.

    Every target value `t` is replaced by `t * (1 - eps) + 0.5 * eps`, i.e. it
    is pulled towards 0.5 by a factor of `eps`.
    """

    def __init__(self, eps:float=0.1, **kwargs):
        """Store the smoothing factor and initialise the parent loss.

        `eps` is the smoothing strength. All other keyword arguments are
        forwarded to `BCEWithLogitsLossFlat`; `thresh` falls back to 0.1 when
        the caller does not provide it.
        """
        self.eps = eps
        # Use setdefault instead of an explicit `thresh=0.1` argument so that a
        # caller-supplied `thresh` overrides the default rather than raising
        # "got multiple values for keyword argument 'thresh'".
        kwargs.setdefault("thresh", 0.1)
        super().__init__(**kwargs)

    def __call__(self, inp, targ, **kwargs):
        """Smooth `targ` and delegate to the parent loss with the smoothed targets."""
        targ_smooth = targ.float() * (1. - self.eps) + 0.5 * self.eps
        return super().__call__(inp, targ_smooth, **kwargs)
# Path of the exported learner file, resolved against the working directory.
MODEL_PATH = 'models.pkl'
# Threshold handed to the replacement loss function; 0.4 is used to remove
# more false positives.
DECODE_THRESH = 0.4

learn = load_learner(MODEL_PATH)
# Swap the loaded learner's loss function for one that uses DECODE_THRESH.
learn.loss_func = BCEWithLogitsLossFlat(thresh=DECODE_THRESH)
def predict(image, vtt):
    """Tag every thumbnail of a sprite sheet and report the best frame per tag.

    Args:
        image: The sprite sheet, in any form accepted by `PILImage.create`.
        vtt: Base64-encoded WebVTT text, optionally carrying the
            "data:text/vtt;base64," prefix. Its `xywh=` entries locate the
            thumbnails inside the sprite.

    Returns:
        A dict mapping each predicted tag (underscores replaced by spaces) to
        `{'prob': float, 'offset': (x, y, w, h)}`: the highest probability seen
        for that tag across all frames, and the `xywh` values of the frame that
        produced it. The dict is empty when the VTT lists no thumbnails.
    """
    vtt = base64.b64decode(vtt.replace("data:text/vtt;base64,", ""))
    sprite = PILImage.create(image)

    # Each offset is (x, y, width, height) as written in the VTT.
    offsets = list(getVTToffsets(vtt))
    if not offsets:
        # Nothing to classify: avoid feeding an empty dataset to the learner.
        return {}

    # `crop` expects (left, upper, right, lower), so width/height are added to
    # the origin to obtain the far corner.
    images = [
        PILImage.create(np.asarray(sprite.crop((x, y, x + w, y + h))))
        for x, y, w, h in offsets
    ]

    # create dataset
    test_dl = learn.dls.test_dl(images, bs=64)
    # get predictions
    probabilities, _, activations = learn.get_preds(dl=test_dl, with_decoded=True)

    # swivel into tags list from activations
    vocab = learn.dls.vocab
    tags = {}
    for frame_idx, frame_flags in enumerate(activations):
        for class_idx, is_active in enumerate(frame_flags):
            if not is_active:
                continue
            tag = vocab[class_idx].replace("_", " ")
            prob = float(probabilities[frame_idx][class_idx])
            best = tags.setdefault(tag, {'prob': 0, 'offset': ()})
            # Keep only the frame with the highest probability for this tag.
            if best['prob'] < prob:
                best['prob'] = prob
                best['offset'] = offsets[frame_idx]
    return tags
def getVTToffsets(vtt: bytes):
    """Yield the `xywh` rectangle of every thumbnail entry in a WebVTT document.

    Args:
        vtt: UTF-8 encoded WebVTT content. Any line containing "xywh=" is
            expected to end with "xywh=<x>,<y>,<w>,<h>".

    Yields:
        `(x, y, w, h)` tuples of ints, in document order.

    Raises:
        UnicodeDecodeError: If `vtt` is not valid UTF-8.
        ValueError: If an "xywh=" entry does not hold exactly four integers.
    """
    for line in vtt.decode("utf-8").splitlines():
        if "xywh=" not in line:
            continue
        # Only the text after the last "xywh=" on the line is parsed.
        fields = line.rpartition("xywh=")[2].split(",")
        # No truthiness test on the parsed values: an x offset of 0 is a valid
        # frame (the first sprite column) and must be yielded like any other.
        x, y, w, h = (int(field) for field in fields)
        yield x, y, w, h
# Web UI: a sprite image and the base64 VTT text go in; the dict returned by
# `predict` is shown as JSON.
demo = gr.Interface(
    fn=predict,
    inputs=[gr.Image(), gr.Textbox(label="VTT file")],
    outputs=gr.JSON(label=""),
)
# NOTE(review): `enable_queue` appears to be deprecated in newer Gradio
# releases in favour of `.queue()` — confirm against the pinned version.
# server_name="0.0.0.0" makes the server listen on all network interfaces.
demo.launch(enable_queue=True, server_name="0.0.0.0")
|