# stashtag / app.py — HuggingFace Space entry point
# (commit 3048dcc: "add new endpoint for marker generation", by cc1234)
import base64
from uuid import uuid4
import gradio as gr
from fastcore.all import *
from fastai.vision.all import *
import numpy as np
import timm
def parent_labels(o):
    """Derive labels from the item's parent directory name (comma-separated)."""
    folder_name = Path(o).parent.name
    return folder_name.split(",")
class LabelSmoothingBCEWithLogitsLossFlat(BCEWithLogitsLossFlat):
    """Flattened BCE-with-logits loss whose hard 0/1 targets are label-smoothed.

    Each target t becomes t*(1-eps) + eps/2, pulling labels toward 0.5 by eps.
    """

    def __init__(self, eps: float = 0.1, **kwargs):
        # Smoothing strength; kept on the instance so __call__ can apply it.
        self.eps = eps
        # NOTE(review): thresh is fixed at 0.1 here regardless of kwargs — preserved as-is.
        super().__init__(thresh=0.1, **kwargs)

    def __call__(self, inp, targ, **kwargs):
        smoothed = 0.5 * self.eps + targ.float() * (1. - self.eps)
        return super().__call__(inp, smoothed, **kwargs)
# Load the exported fastai learner (multi-label image classifier) from disk.
learn = load_learner('models.pkl')
# set a new loss function with a threshold of 0.4 to remove more false positives
learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)
def predict_tags(image, vtt, threshold=0.4):
    """Predict tags for every frame of a sprite sheet described by a VTT file.

    Args:
        image: sprite-sheet image (anything `PILImage.create` accepts).
        vtt: base64-encoded WebVTT text (optionally prefixed with a data URI header)
             whose cues carry `xywh=` crop coordinates into the sprite sheet.
        threshold: activation threshold for accepting a tag; falsy values fall
             back to 0.4 (backward-compatible with the original behavior).

    Returns:
        dict mapping tag name -> {'prob', 'offset', 'frame', 'time'} for the
        frame where that tag scored highest. Empty dict if the VTT has no frames.
    """
    vtt = base64.b64decode(vtt.replace("data:text/vtt;base64,", ""))
    sprite = PILImage.create(image)
    offsets = []
    times = []
    images = []
    for left, top, right, bottom, time_seconds in getVTToffsets(vtt):
        times.append(time_seconds)
        offsets.append((left, top, right, bottom))
        # xywh semantics: right/bottom are width/height, so the PIL crop box
        # is (x, y, x + w, y + h).
        cut_frame = sprite.crop((left, top, left + right, top + bottom))
        images.append(PILImage.create(np.asarray(cut_frame)))
    # Guard: an empty/invalid VTT would otherwise feed an empty batch to test_dl.
    if not images:
        return {}
    threshold = threshold or 0.4
    # Temporarily swap the (global) learner's loss threshold; restore it in a
    # finally block so an exception in get_preds cannot leave a stale threshold.
    learn.loss_func = BCEWithLogitsLossFlat(thresh=threshold)
    try:
        test_dl = learn.dls.test_dl(images, bs=64)
        probabilities, _, activations = learn.get_preds(dl=test_dl, with_decoded=True)
    finally:
        learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)
    # Pivot per-frame activations into a tag -> best-frame mapping, keeping the
    # highest-probability occurrence of each tag.
    tags = {}
    for frame_idx, activation in enumerate(activations):
        for class_idx, active in enumerate(activation):
            if not active:
                continue
            tag = learn.dls.vocab[class_idx].replace("_", " ")
            if tag not in tags:
                tags[tag] = {'prob': 0, 'offset': (), 'frame': 0}
            prob = float(probabilities[frame_idx][class_idx])
            if tags[tag]['prob'] < prob:
                tags[tag]['prob'] = prob
                tags[tag]['offset'] = offsets[frame_idx]
                tags[tag]['frame'] = frame_idx
                tags[tag]['time'] = times[frame_idx]
    return tags
def predict_markers(image, vtt, threshold=0.4):
    """Predict scene markers: frames whose top tag persists into the next frame.

    Args:
        image: sprite-sheet image (anything `PILImage.create` accepts).
        vtt: base64-encoded WebVTT text (optionally with a data URI prefix)
             whose cues carry `xywh=` crop coordinates into the sprite sheet.
        threshold: activation threshold; falsy values fall back to 0.4.

    Returns:
        list of {'offset', 'frame', 'time', 'tag', 'id'} dicts, one per marker,
        deduplicated so consecutive markers never share a tag label. Empty list
        if the VTT yields no frames.
    """
    vtt = base64.b64decode(vtt.replace("data:text/vtt;base64,", ""))
    sprite = PILImage.create(image)
    offsets = []
    times = []
    images = []
    for left, top, right, bottom, time_seconds in getVTToffsets(vtt):
        times.append(time_seconds)
        offsets.append((left, top, right, bottom))
        # xywh semantics: right/bottom are width/height -> crop box (x, y, x+w, y+h).
        cut_frame = sprite.crop((left, top, left + right, top + bottom))
        images.append(PILImage.create(np.asarray(cut_frame)))
    # Guard: avoid handing an empty batch to test_dl on an empty/invalid VTT.
    if not images:
        return []
    threshold = threshold or 0.4
    # Temporarily swap the (global) learner's loss threshold; restore in finally
    # so a failure inside get_preds cannot leave a stale threshold behind.
    learn.loss_func = BCEWithLogitsLossFlat(thresh=threshold)
    try:
        test_dl = learn.dls.test_dl(images, bs=64)
        probabilities, _, activations = learn.get_preds(dl=test_dl, with_decoded=True)
    finally:
        learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)
    # Collect, per frame, the list of tags that fired with their probabilities.
    all_data_per_frame = []
    for frame_idx, activation in enumerate(activations):
        frame_data = {
            'offset': offsets[frame_idx],
            'frame': frame_idx,
            'time': times[frame_idx],
            'tags': [],
        }
        ftags = [
            {
                'label': learn.dls.vocab[class_idx].replace("_", " "),
                'prob': float(probabilities[frame_idx][class_idx]),
            }
            for class_idx, active in enumerate(activation)
            if active
        ]
        if not ftags:
            continue
        frame_data['tags'] = ftags
        all_data_per_frame.append(frame_data)
    # Keep only frames whose tags also appear in the immediately following
    # frame — a tag persisting across two frames is treated as a real event.
    filtered = []
    for idx, frame_data in enumerate(all_data_per_frame):
        if idx == len(all_data_per_frame) - 1:
            break
        next_frame_data = all_data_per_frame[idx + 1]
        frame_data['tags'] = [
            tag
            for tag in frame_data['tags']
            for next_tag in next_frame_data['tags']
            if tag['label'] == next_tag['label']
        ]
        if frame_data['tags']:
            filtered.append(frame_data)
    # Collapse runs of frames that share any tag, keeping the first frame of
    # each run and its single highest-probability tag.
    last_tag = set()
    results = []
    for frame_data in filtered:
        tags = {s['label'] for s in frame_data['tags']}
        if tags.intersection(last_tag):
            continue
        last_tag = tags
        frame_data['tag'] = sorted(frame_data['tags'], key=lambda x: x['prob'], reverse=True)[0]
        del frame_data['tags']
        # Unique id so the client can reference each marker individually.
        frame_data['id'] = str(uuid4())
        results.append(frame_data)
    return results
def getVTToffsets(vtt):
    """Yield (left, top, width, height, start_seconds) for each cue in a WebVTT.

    Args:
        vtt: raw VTT file contents as bytes (UTF-8).

    Yields:
        One tuple per `xywh=` line, paired with the start time of the most
        recent `-->` cue line (0 if none seen yet). Note the 3rd/4th values
        are width/height, despite the historical right/bottom naming.
    """
    time_seconds = 0
    left = top = right = bottom = None
    for line in vtt.decode("utf-8").split("\n"):
        line = line.strip()
        if "-->" in line:
            # grab the start time
            # 00:00:00.000 --> 00:00:41.000
            start = line.split("-->")[0].strip().split(":")
            # convert to seconds
            time_seconds = (
                int(start[0]) * 3600
                + int(start[1]) * 60
                + float(start[2])
            )
            # Reset coordinates so a cue without an xywh line yields nothing.
            left = top = right = bottom = None
        elif "xywh=" in line:
            left, top, right, bottom = line.split("xywh=")[-1].split(",")
            left, top, right, bottom = (
                int(left),
                int(top),
                int(right),
                int(bottom),
            )
        else:
            continue
        # BUG FIX: the original used `if not left`, which also skipped valid
        # frames whose x-offset is 0 (the whole leftmost sprite column).
        if left is None:
            continue
        yield left, top, right, bottom, time_seconds
# create a gradio interface with 2 tabs
# Tab 1: per-tag best-frame predictions (sprite image + base64 VTT in, JSON out).
tag = gr.Interface(
    fn=predict_tags,
    inputs=[
        gr.Image(),
        gr.Textbox(label="VTT file"),
        gr.Number(value=0.4, label="Threshold")
    ],
    outputs=gr.JSON(label=""),
)
# Tab 2: scene-marker predictions over the same inputs.
marker = gr.Interface(
    fn=predict_markers,
    inputs=[
        gr.Image(),
        gr.Textbox(label="VTT file"),
        gr.Number(value=0.4, label="Threshold")
    ],
    outputs=gr.JSON(label=""),
)
# Serve both interfaces; 0.0.0.0 so the Space/container is reachable externally.
gr.TabbedInterface(
    [tag, marker], ["tag", "marker"]
).launch(enable_queue=True, server_name="0.0.0.0")