import base64
import gradio as gr
from fastcore.all import *
from fastai.vision.all import *
import numpy as np
import timm  # NOTE(review): unused directly; presumably required so architectures inside models.pkl can be rebuilt — confirm before removing


def parent_labels(o):
    "Label `item` with the parent folder name."
    # Multi-label convention: the parent folder name is a comma-separated list of labels.
    # NOTE(review): presumably kept at module level so the pickled learner in
    # models.pkl can resolve it when unpickling — confirm before renaming/removing.
    return Path(o).parent.name.split(",")


class LabelSmoothingBCEWithLogitsLossFlat(BCEWithLogitsLossFlat):
    """BCE-with-logits loss that smooths multi-label targets toward 0.5.

    Targets are mapped as ``t * (1 - eps) + 0.5 * eps`` before the loss is computed.
    NOTE(review): presumably only needed here so ``load_learner`` can unpickle a
    learner that was trained with this loss — confirm.
    """

    def __init__(self, eps: float = 0.1, **kwargs):
        self.eps = eps
        super().__init__(thresh=0.1, **kwargs)

    def __call__(self, inp, targ, **kwargs):
        targ_smooth = targ.float() * (1. - self.eps) + 0.5 * self.eps
        return super().__call__(inp, targ_smooth, **kwargs)


# NOTE(review): load_learner unpickles the file — only ever point this at a trusted, local model file.
learn = load_learner('models.pkl')
# set a new loss function with a threshold of 0.4 to remove more false positives
# (the decoded activations returned by get_preds(with_decoded=True) are thresholded by the loss function)
learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)


def predict(image, vtt):
    """Tag every thumbnail of a sprite sheet and keep the best-scoring frame per tag.

    Args:
        image: the sprite-sheet image as delivered by ``gr.Image()``; it is passed
            straight to ``PILImage.create``.
        vtt: a WebVTT file encoded as base64, optionally prefixed with
            ``"data:text/vtt;base64,"``. Each cue with an ``xywh=`` fragment
            locates one thumbnail inside the sprite.

    Returns:
        A dict mapping each predicted tag (underscores replaced by spaces) to
        ``{'prob': float, 'offset': (left, top, width, height)}``. The entry holds the
        highest probability seen for that tag and the offset of the frame that
        produced it. The dict is empty when the VTT contains no ``xywh=`` cues.
    """
    vtt_bytes = base64.b64decode(vtt.replace("data:text/vtt;base64,", ""))
    sprite = PILImage.create(image)

    offsets = list(getVTToffsets(vtt_bytes))
    if not offsets:
        # No thumbnails located: skip building a DataLoader / running inference on an empty batch.
        return {}

    # PIL crop boxes are (left, upper, right, lower), so width/height are added to the origin.
    images = [
        PILImage.create(np.asarray(sprite.crop((left, top, left + width, top + height))))
        for left, top, width, height in offsets
    ]

    test_dl = learn.dls.test_dl(images, bs=64)
    # get_preds(with_decoded=True) -> (probabilities, targets, decoded); decoded holds the
    # per-label activations after thresholding by learn.loss_func.
    probabilities, _, activations = learn.get_preds(dl=test_dl, with_decoded=True)

    # Pivot the per-frame activations into a per-tag summary.
    tags = {}
    for frame_idx, activation in enumerate(activations):
        for vocab_idx, active in enumerate(activation):
            if not active:
                continue
            tag = learn.dls.vocab[vocab_idx].replace("_", " ")
            prob = float(probabilities[frame_idx][vocab_idx])
            entry = tags.setdefault(tag, {'prob': 0, 'offset': ()})
            if entry['prob'] < prob:
                entry['prob'] = prob
                entry['offset'] = offsets[frame_idx]
    return tags


def getVTToffsets(vtt):
    """Yield ``(left, top, width, height)`` for every ``xywh=`` cue in a VTT file.

    Args:
        vtt: the raw VTT file contents as UTF-8 encoded bytes.

    Yields:
        A tuple of four ints taken from the comma-separated values after ``xywh=``,
        in file order. Lines without ``xywh=`` are ignored.
    """
    for line in vtt.decode("utf-8").split("\n"):
        line = line.strip()
        if "xywh=" not in line:
            continue
        # Every cue is kept, including those at x == 0 (the sprite's first column).
        left, top, width, height = (int(value) for value in line.split("xywh=")[-1].split(","))
        yield left, top, width, height


gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(),
        gr.Textbox(label="VTT file"),
    ],
    outputs=gr.JSON(label=""),
    # NOTE(review): `enable_queue` is deprecated/removed from launch() in newer Gradio
    # releases (use `.queue()` instead) — confirm against the pinned gradio version.
).launch(enable_queue=True, server_name="0.0.0.0")