Spaces:
Running
Running
import base64 | |
import gradio as gr | |
from fastcore.all import * | |
from fastai.vision.all import * | |
import numpy as np | |
import timm | |
def parent_labels(o):
    """Return the comma-separated labels encoded in the parent folder name of ``o``.

    Example: ``"train/cat,dog/001.jpg"`` -> ``["cat", "dog"]``.

    NOTE(review): presumably this is the label function the learner was trained
    with and must stay defined under this name for ``load_learner`` to unpickle
    ``models.pkl`` — confirm before renaming.
    """
    folder_name = Path(o).parent.name
    labels = folder_name.split(",")
    return labels
class LabelSmoothingBCEWithLogitsLossFlat(BCEWithLogitsLossFlat):
    """`BCEWithLogitsLossFlat` that applies label smoothing to the targets.

    Targets are mapped to ``targ * (1 - eps) + 0.5 * eps`` before the loss is
    computed, so positives become ``1 - eps/2`` and negatives ``eps/2``.

    NOTE(review): this script does not use the class for inference. It
    presumably only needs to be importable so the pickled learner can load.
    """

    def __init__(self, eps: float = 0.1, thresh: float = 0.1, **kwargs):
        """
        Args:
            eps: label-smoothing strength.
            thresh: decode threshold forwarded to ``BCEWithLogitsLossFlat``.
                It used to be hard-coded, so passing ``thresh=`` via kwargs
                raised a duplicate-keyword ``TypeError``. The default is unchanged.
            **kwargs: forwarded unchanged to ``BCEWithLogitsLossFlat``.
        """
        self.eps = eps
        super().__init__(thresh=thresh, **kwargs)

    def __call__(self, inp, targ, **kwargs):
        """Return the parent loss of ``inp`` against the smoothed ``targ``."""
        targ_smooth = targ.float() * (1. - self.eps) + 0.5 * self.eps
        return super().__call__(inp, targ_smooth, **kwargs)
# Load the exported fastai learner ('models.pkl', resolved against the current
# working directory). `predict` below uses it as a module-level global.
# NOTE(review): load_learner unpickles this file, so it must come from a trusted
# source. The pickle presumably references `parent_labels` and
# `LabelSmoothingBCEWithLogitsLossFlat`, which is why they are defined above.
learn = load_learner('models.pkl') | |
# Swap in a loss function with a threshold of 0.4 to remove more false positives.
# `predict` relies on this threshold through `get_preds(with_decoded=True)`.
learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4) | |
def predict(image, vtt):
    """Tag the frames of a video sprite sheet with the loaded learner.

    Args:
        image: the sprite sheet as delivered by the ``gr.Image`` input
            (anything ``PILImage.create`` accepts).
        vtt: a WebVTT document as base64 text, optionally prefixed with
            ``data:text/vtt;base64,``. Each ``xywh=x,y,w,h`` entry locates
            one frame inside the sprite.

    Returns:
        A dict mapping each predicted tag (underscores replaced by spaces) to
        ``{'prob': float, 'offset': (x, y, w, h)}`` for the frame with the
        highest probability for that tag. The dict is empty when the VTT
        lists no frames.
    """
    vtt = base64.b64decode(vtt.replace("data:text/vtt;base64,", ""))
    sprite = PILImage.create(image)

    offsets = []
    images = []
    for left, top, width, height in getVTToffsets(vtt):
        offsets.append((left, top, width, height))
        # PIL crop boxes are (left, upper, right, lower), so add the size.
        frame = sprite.crop((left, top, left + width, top + height))
        # Re-wrap the crop as a fastai PILImage for the DataLoader transforms.
        images.append(PILImage.create(np.asarray(frame)))

    # Nothing to classify. An empty DataLoader makes fastai's get_preds fail,
    # so return "no tags" instead of raising.
    if not images:
        return {}

    test_dl = learn.dls.test_dl(images, bs=64)
    # `decoded` is the per-class boolean output thresholded by learn.loss_func.
    probabilities, _, decoded = learn.get_preds(dl=test_dl, with_decoded=True)

    # Keep, for every tag, the frame that predicts it most confidently.
    vocab = learn.dls.vocab
    tags = {}
    for frame_idx, frame_decoded in enumerate(decoded):
        for class_idx, is_positive in enumerate(frame_decoded):
            if not is_positive:
                continue
            # Several vocab entries may collapse to one tag after the replace;
            # they share one entry, which keeps the maximum.
            tag = vocab[class_idx].replace("_", " ")
            prob = float(probabilities[frame_idx][class_idx])
            best = tags.setdefault(tag, {'prob': 0, 'offset': ()})
            if best['prob'] < prob:
                best['prob'] = prob
                best['offset'] = offsets[frame_idx]
    return tags
def getVTToffsets(vtt):
    """Yield the sprite-frame rectangles referenced by a WebVTT thumbnail track.

    Every line containing ``xywh=`` is parsed as comma-separated ``x,y,w,h``
    integers (e.g. ``sprite.jpg#xywh=0,0,160,90``). All other lines are ignored.

    Args:
        vtt: the VTT document as UTF-8 ``bytes`` (an already-decoded ``str``
            is accepted as well).

    Yields:
        ``(left, top, width, height)`` int tuples, in file order.

    Raises:
        ValueError: if an ``xywh=`` value does not hold exactly four integers.
    """
    if isinstance(vtt, (bytes, bytearray)):
        vtt = vtt.decode("utf-8")
    for line in vtt.split("\n"):
        line = line.strip()
        if "xywh=" not in line:
            continue
        left, top, width, height = (
            int(value) for value in line.split("xywh=")[-1].split(",")
        )
        # 0 is a valid coordinate (the sprite's first column/row), so every
        # parsed rectangle is yielded. There is deliberately no truthiness
        # test on `left`.
        yield left, top, width, height
# Web UI: the sprite-sheet image and its VTT (base64 text, optionally as a
# data: URL) go in; the per-tag JSON produced by `predict` comes out.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(),
        gr.Textbox(label="VTT file"),
    ],
    outputs=gr.JSON(label=""),
)
# `launch(enable_queue=True)` is deprecated in Gradio 3.x and no longer
# accepted in Gradio 4. `.queue()` enables request queueing in both and
# returns the app, so `launch` can be chained. server_name="0.0.0.0" binds
# all network interfaces so the app is reachable from outside its container.
demo.queue().launch(server_name="0.0.0.0")