cc1234 commited on
Commit
5e2f8af
1 Parent(s): 9f0e41a

v0.1 release

Browse files
Files changed (5) hide show
  1. .gitattributes +2 -1
  2. .gitignore +3 -0
  3. app.py +88 -0
  4. models.pkl +3 -0
  5. requirements.txt +4 -0
.gitattributes CHANGED
@@ -2,7 +2,6 @@
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
@@ -32,3 +31,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
 
5
  *.ftz filter=lfs diff=lfs merge=lfs -text
6
  *.gz filter=lfs diff=lfs merge=lfs -text
7
  *.h5 filter=lfs diff=lfs merge=lfs -text
 
31
  *.zip filter=lfs diff=lfs merge=lfs -text
32
  *.zst filter=lfs diff=lfs merge=lfs -text
33
  *tfevents* filter=lfs diff=lfs merge=lfs -text
34
+ *.db filter=lfs diff=lfs merge=lfs -text
35
+ face.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ venv
2
+ flagged
3
+ __pycache__
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import gradio as gr
3
+
4
+ from fastcore.all import *
5
+ from fastai.vision.all import *
6
+ import numpy as np
7
+ import timm
8
+
9
+
10
+ def parent_labels(o):
11
+ "Label `item` with the parent folder name."
12
+ return Path(o).parent.name.split(",")
13
+
14
+ class LabelSmoothingBCEWithLogitsLossFlat(BCEWithLogitsLossFlat):
15
+ def __init__(self, eps:float=0.1, **kwargs):
16
+ self.eps = eps
17
+ super().__init__(thresh=0.1, **kwargs)
18
+
19
+ def __call__(self, inp, targ, **kwargs):
20
+ targ_smooth = targ.float() * (1. - self.eps) + 0.5 * self.eps
21
+ return super().__call__(inp, targ_smooth, **kwargs)
22
+
23
+
24
+ learn = load_learner('models.pkl')
25
+ # set a new loss function with a threshold of 0.4 to remove more false positives
26
+ learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)
27
+
28
+
29
+ def predict(image, vtt):
30
+ vtt = base64.b64decode(vtt.replace("data:text/vtt;base64,", ""))
31
+ sprite = PILImage.create(image)
32
+ offsets = []
33
+ images = []
34
+ for left, top, right, bottom in getVTToffsets(vtt):
35
+ offsets.append((left, top, right, bottom))
36
+ cut_frame = sprite.crop((left, top, left + right, top + bottom))
37
+ images.append(PILImage.create(np.asarray(cut_frame)))
38
+
39
+ # create dataset
40
+ test_dl = learn.dls.test_dl(images, bs=64)
41
+ # get predictions
42
+ probabilities, _, activations = learn.get_preds(dl=test_dl, with_decoded=True)
43
+ # swivel into tags list from activations
44
+ tags = {}
45
+ for x, activation in enumerate(activations):
46
+ for idx, i in enumerate(activation):
47
+ if i:
48
+ tag = learn.dls.vocab[idx]
49
+ tag = tag.replace("_", " ")
50
+ if tag not in tags:
51
+ tags[tag] = {'prob': 0, 'offset': ()}
52
+ prob = float(probabilities[x][idx])
53
+ if tags[tag]['prob'] < prob:
54
+ tags[tag]['prob'] = prob
55
+ tags[tag]['offset'] = offsets[x]
56
+
57
+ return tags
58
+
59
+
60
+ def getVTToffsets(vtt):
61
+ left = top = right = bottom = None
62
+ for line in vtt.decode("utf-8").split("\n"):
63
+ line = line.strip()
64
+ if "xywh=" in line:
65
+ left, top, right, bottom = line.split("xywh=")[-1].split(",")
66
+ left, top, right, bottom = (
67
+ int(left),
68
+ int(top),
69
+ int(right),
70
+ int(bottom),
71
+ )
72
+ else:
73
+ continue
74
+
75
+ if not left:
76
+ continue
77
+
78
+ yield left, top, right, bottom
79
+
80
+
81
+ gr.Interface(
82
+ fn=predict,
83
+ inputs=[
84
+ gr.Image(),
85
+ gr.Textbox(label="VTT file"),
86
+ ],
87
+ outputs=gr.JSON(label=""),
88
+ ).launch(enable_queue=True, server_name="0.0.0.0")
models.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b56270c11a4aad0963d314f20806e72b911440771777044dbc1ad5ad9de48e5e
3
+ size 22972595
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ fastai==2.7.7
2
+ numpy==1.21.5
3
+ timm==0.6.7
4
+ gradio==3.18