import base64
from uuid import uuid4

import gradio as gr
from fastcore.all import *
from fastai.vision.all import *
import numpy as np
import timm


def parent_labels(o):
    "Label `item` with the parent folder name."
    # NOTE: required at module level so load_learner can unpickle the model.
    return Path(o).parent.name.split(",")


class LabelSmoothingBCEWithLogitsLossFlat(BCEWithLogitsLossFlat):
    """BCE-with-logits loss with label smoothing.

    Required at module level so `load_learner` can unpickle the model,
    even though inference below swaps in a plain `BCEWithLogitsLossFlat`.
    """

    def __init__(self, eps: float = 0.1, **kwargs):
        self.eps = eps
        super().__init__(thresh=0.1, **kwargs)

    def __call__(self, inp, targ, **kwargs):
        # Smooth the 0/1 targets towards 0.5 by `eps`.
        targ_smooth = targ.float() * (1. - self.eps) + 0.5 * self.eps
        return super().__call__(inp, targ_smooth, **kwargs)


learn = load_learner('models.pkl')
# Set a new loss function with a threshold of 0.4 to remove more false positives.
learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)


def getVTToffsets(vtt):
    """Parse a WebVTT sprite index and yield one tuple per sprite tile.

    Yields `(x, y, width, height, start_seconds)` for every cue that carries
    an `xywh=` media-fragment. `vtt` is the raw VTT file as bytes.
    """
    time_seconds = 0
    x = y = width = height = None
    for line in vtt.decode("utf-8").split("\n"):
        line = line.strip()
        if "-->" in line:
            # Grab the cue start time, e.g. "00:00:00.000 --> 00:00:41.000",
            # and convert it to seconds.
            start = line.split("-->")[0].strip().split(":")
            time_seconds = (
                int(start[0]) * 3600 + int(start[1]) * 60 + float(start[2])
            )
            x = y = width = height = None
        elif "xywh=" in line:
            # Media fragment gives x, y, width, height of the tile.
            x, y, width, height = (
                int(v) for v in line.split("xywh=")[-1].split(",")
            )
        else:
            continue
        # Bug fix: the original used `if not x`, which silently dropped any
        # tile with x == 0 (the whole first sprite column). Only skip when
        # no xywh fragment has been seen yet for this cue.
        if x is None:
            continue
        yield x, y, width, height, time_seconds


def _extract_frames(image, vtt):
    """Crop every sprite tile described by `vtt` out of the sprite sheet.

    `vtt` is a base64 data-URI string ("data:text/vtt;base64,...").
    Returns `(images, offsets, times)` — parallel lists of cropped frames,
    their `(x, y, w, h)` offsets, and their start times in seconds.
    """
    vtt = base64.b64decode(vtt.replace("data:text/vtt;base64,", ""))
    sprite = PILImage.create(image)
    offsets, times, images = [], [], []
    for x, y, width, height, time_seconds in getVTToffsets(vtt):
        times.append(time_seconds)
        offsets.append((x, y, width, height))
        cut_frame = sprite.crop((x, y, x + width, y + height))
        images.append(PILImage.create(np.asarray(cut_frame)))
    return images, offsets, times


def _get_predictions(images, threshold):
    """Run the learner over `images` with the decode threshold set to `threshold`.

    Temporarily swaps the learner's loss function so `with_decoded` uses the
    requested threshold, and always restores the default 0.4 afterwards
    (the original only restored it on success).
    Returns `(probabilities, activations)`.
    """
    threshold = threshold or 0.4  # a falsy/empty threshold falls back to 0.4
    learn.loss_func = BCEWithLogitsLossFlat(thresh=threshold)
    try:
        test_dl = learn.dls.test_dl(images, bs=64)
        probabilities, _, activations = learn.get_preds(dl=test_dl, with_decoded=True)
    finally:
        learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)
    return probabilities, activations


def predict_tags(image, vtt, threshold=0.4):
    """Tag a sprite sheet and return the best occurrence of each tag.

    Returns a dict mapping tag label to
    `{'prob', 'offset', 'frame', 'time'}` for the frame where that tag
    scored highest.
    """
    images, offsets, times = _extract_frames(image, vtt)
    probabilities, activations = _get_predictions(images, threshold)

    # Swivel the per-frame activations into one entry per tag, keeping the
    # highest-probability frame for each.
    tags = {}
    for frame_idx, activation in enumerate(activations):
        for vocab_idx, active in enumerate(activation):
            if not active:
                continue
            tag = learn.dls.vocab[vocab_idx].replace("_", " ")
            entry = tags.setdefault(tag, {'prob': 0, 'offset': (), 'frame': 0})
            prob = float(probabilities[frame_idx][vocab_idx])
            if entry['prob'] < prob:
                entry['prob'] = prob
                entry['offset'] = offsets[frame_idx]
                entry['frame'] = frame_idx
                entry['time'] = times[frame_idx]
    return tags


def predict_markers(image, vtt, threshold=0.4):
    """Suggest scene markers: frames where a tag persists into the next frame.

    Returns a list of `{'offset', 'frame', 'time', 'tag', 'id'}` dicts, one
    per distinct run of consecutive frames sharing a tag.
    """
    images, offsets, times = _extract_frames(image, vtt)
    probabilities, activations = _get_predictions(images, threshold)

    # One record per frame that has at least one active tag.
    all_data_per_frame = []
    for frame_idx, activation in enumerate(activations):
        ftags = [
            {
                'label': learn.dls.vocab[vocab_idx].replace("_", " "),
                'prob': float(probabilities[frame_idx][vocab_idx]),
            }
            for vocab_idx, active in enumerate(activation)
            if active
        ]
        if not ftags:
            continue
        all_data_per_frame.append({
            'offset': offsets[frame_idx],
            'frame': frame_idx,
            'time': times[frame_idx],
            'tags': ftags,
        })

    # Keep only tags that also appear in the following frame — a single-frame
    # detection is treated as noise. The last frame has no successor and is
    # always dropped, matching the original behavior.
    filtered = []
    for idx, frame_data in enumerate(all_data_per_frame[:-1]):
        next_labels = {t['label'] for t in all_data_per_frame[idx + 1]['tags']}
        frame_data['tags'] = [
            t for t in frame_data['tags'] if t['label'] in next_labels
        ]
        if frame_data['tags']:
            filtered.append(frame_data)

    # Collapse consecutive frames sharing any tag into a single marker,
    # labeled with that frame's highest-probability tag.
    last_tags = set()
    results = []
    for frame_data in filtered:
        labels = {t['label'] for t in frame_data['tags']}
        if labels.intersection(last_tags):
            continue
        last_tags = labels
        frame_data['tag'] = max(frame_data['tags'], key=lambda t: t['prob'])
        del frame_data['tags']
        # Add a unique id to the frame.
        frame_data['id'] = str(uuid4())
        results.append(frame_data)
    return results


# Create a gradio interface with 2 tabs.
tag = gr.Interface(
    fn=predict_tags,
    inputs=[
        gr.Image(),
        gr.Textbox(label="VTT file"),
        gr.Number(value=0.4, label="Threshold"),
    ],
    outputs=gr.JSON(label=""),
)
marker = gr.Interface(
    fn=predict_markers,
    inputs=[
        gr.Image(),
        gr.Textbox(label="VTT file"),
        gr.Number(value=0.4, label="Threshold"),
    ],
    outputs=gr.JSON(label=""),
)
gr.TabbedInterface(
    [tag, marker], ["tag", "marker"]
).launch(enable_queue=True, server_name="0.0.0.0")