# stashtag / app.py
# Author: cc1234 - v0.1 release (commit 5e2f8af)
import base64
import gradio as gr
from fastcore.all import *
from fastai.vision.all import *
import numpy as np
import timm
def parent_labels(o):
    """Return the labels encoded in the name of ``o``'s parent directory.

    The directory name is split on commas, so an item stored at
    ``data/dogs,outdoor/img.jpg`` yields ``["dogs", "outdoor"]``.

    NOTE(review): presumably referenced by the pickled learner in
    ``models.pkl`` (as the label getter), so keep this name at module level.
    """
    folder_name = Path(o).parent.name
    return folder_name.split(",")
class LabelSmoothingBCEWithLogitsLossFlat(BCEWithLogitsLossFlat):
    """``BCEWithLogitsLossFlat`` with label smoothing applied to the targets.

    For binary targets, 1 becomes ``1 - eps/2`` and 0 becomes ``eps/2``
    before the base loss is computed.

    NOTE(review): presumably referenced by the pickled learner loaded from
    ``models.pkl``; it must stay importable under this name.
    """

    def __init__(self, eps: float = 0.1, thresh: float = 0.1, **kwargs):
        """
        Args:
            eps: smoothing amount; 0 disables smoothing.
            thresh: decision threshold forwarded to ``BCEWithLogitsLossFlat``
                (defaults to 0.1). Accepting it explicitly lets callers
                override it without a duplicate-keyword ``TypeError``.
            **kwargs: forwarded unchanged to ``BCEWithLogitsLossFlat``.
        """
        self.eps = eps
        super().__init__(thresh=thresh, **kwargs)

    def __call__(self, inp, targ, **kwargs):
        """Return the base loss of ``inp`` against the smoothed ``targ``."""
        # y * (1 - eps) + eps / 2  maps 1 -> 1 - eps/2 and 0 -> eps/2.
        targ_smooth = targ.float() * (1. - self.eps) + 0.5 * self.eps
        return super().__call__(inp, targ_smooth, **kwargs)
# NOTE(review): load_learner deserializes models.pkl via pickle; only ever
# point it at a trusted file. It presumably needs parent_labels and
# LabelSmoothingBCEWithLogitsLossFlat (defined above) to unpickle.
learn = load_learner('models.pkl')

# Decision threshold used at inference time. It is higher than the 0.1 the
# training loss used, to remove more false positives.
_INFERENCE_THRESHOLD = 0.4

# Replace the training loss with a plain BCEWithLogitsLossFlat carrying the
# inference threshold (presumably what get_preds(with_decoded=True) uses to
# turn probabilities into per-tag booleans).
learn.loss_func = BCEWithLogitsLossFlat(thresh=_INFERENCE_THRESHOLD)
def predict(image, vtt):
    """Tag the frames of a thumbnail sprite and report the best frame per tag.

    Args:
        image: the sprite sheet, in any form ``PILImage.create`` accepts
            (presumably the array handed over by gradio's ``gr.Image``).
        vtt: the WebVTT file describing the sprite, base64-encoded, either
            raw or wrapped in a data URL such as ``data:text/vtt;base64,...``.

    Returns:
        A dict mapping each detected tag (underscores replaced by spaces) to
        ``{'prob': float, 'offset': (x, y, width, height)}``. These are the
        highest probability seen for that tag and the sprite rectangle of
        the frame that produced it. Returns ``{}`` when the VTT yields no
        frames.
    """
    # Strip any data-URL header (everything up to the first comma), whatever
    # media type or parameters it carries.
    vtt = vtt.strip()
    if vtt.startswith("data:"):
        vtt = vtt.partition(",")[2]
    vtt = base64.b64decode(vtt)

    sprite = PILImage.create(image)
    offsets = []
    frames = []
    for x, y, w, h in getVTToffsets(vtt):
        offsets.append((x, y, w, h))
        # The xywh values are (left, top, width, height) inside the sprite.
        tile = sprite.crop((x, y, x + w, y + h))
        frames.append(PILImage.create(np.asarray(tile)))

    if not frames:
        # Nothing to classify: skip building an empty DataLoader entirely.
        return {}

    test_dl = learn.dls.test_dl(frames, bs=64)
    # `decoded` holds per-tag hits for each frame (presumably probability
    # above the loss_func threshold).
    probabilities, _, decoded = learn.get_preds(dl=test_dl, with_decoded=True)

    vocab = learn.dls.vocab
    tags = {}
    for frame_idx, frame_hits in enumerate(decoded):
        for tag_idx, hit in enumerate(frame_hits):
            if not hit:
                continue
            tag = vocab[tag_idx].replace("_", " ")
            prob = float(probabilities[frame_idx][tag_idx])
            best = tags.setdefault(tag, {'prob': 0, 'offset': ()})
            # Keep only the most confident frame for each tag.
            if best['prob'] < prob:
                best['prob'] = prob
                best['offset'] = offsets[frame_idx]
    return tags
def getVTToffsets(vtt):
    """Yield the sprite rectangle of every WebVTT line carrying ``xywh=``.

    Lines containing an ``xywh=`` media fragment (for example
    ``sprite.jpg#xywh=160,0,160,90``) are parsed; all other lines are
    ignored.

    Args:
        vtt: the VTT document as UTF-8 ``bytes`` or as an already-decoded
            ``str``.

    Yields:
        ``(x, y, width, height)`` tuples of ints, in document order. The last
        two values are a width and a height, not right/bottom coordinates.

    Raises:
        ValueError: if an ``xywh=`` value does not hold four integers.
    """
    text = vtt.decode("utf-8") if isinstance(vtt, (bytes, bytearray)) else vtt
    for line in text.split("\n"):
        line = line.strip()
        if "xywh=" not in line:
            continue
        x, y, w, h = line.split("xywh=")[-1].split(",")
        # No truthiness check on x: 0 is a valid left edge (the first column
        # of the sprite) and must not be skipped.
        yield int(x), int(y), int(w), int(h)
# Web UI: a sprite image plus its base64-encoded VTT file in, tag JSON out.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(),
        gr.Textbox(label="VTT file"),
    ],
    outputs=gr.JSON(label=""),
)
# Enable request queuing with .queue(): launch(enable_queue=...) is deprecated
# in gradio 3.x and no longer accepted by gradio 4.
# server_name="0.0.0.0" listens on all network interfaces.
demo.queue().launch(server_name="0.0.0.0")