cc1234 committed on
Commit
3048dcc
1 Parent(s): 6fcccb0

add new endpoint for marker generation

Browse files
Files changed (2) hide show
  1. app.py +132 -32
  2. requirements.txt +1 -2
app.py CHANGED
@@ -1,14 +1,12 @@
1
  import base64
 
2
  import gradio as gr
3
 
4
  from fastcore.all import *
5
  from fastai.vision.all import *
6
  import numpy as np
7
- import opennsfw2 as n2
8
  import timm
9
 
10
- model = n2.make_open_nsfw_model()
11
-
12
 
13
  def parent_labels(o):
14
  "Label `item` with the parent folder name."
@@ -29,51 +27,137 @@ learn = load_learner('models.pkl')
29
  learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)
30
 
31
 
32
- def predict(image, vtt):
33
  vtt = base64.b64decode(vtt.replace("data:text/vtt;base64,", ""))
34
  sprite = PILImage.create(image)
35
 
36
- pre_process_data = []
37
- for left, top, right, bottom in getVTToffsets(vtt):
38
- cut_frame = sprite.crop((left, top, left + right, top + bottom))
39
- image = n2.preprocess_image(cut_frame, n2.Preprocessing.YAHOO)
40
- pre_process_data.append((np.expand_dims(image, axis=0), cut_frame, (left, top, right, bottom)))
41
-
42
  offsets = []
 
43
  images = []
44
- tensors = [i[0] for i in pre_process_data]
45
- predictions = model.predict(np.vstack(tensors))
46
- for i, prediction in enumerate(predictions):
47
- if prediction[0] < 0.5:
48
- images.append(PILImage.create(np.asarray(pre_process_data[i][1])))
49
- offsets.append(pre_process_data[i][2])
50
-
 
51
  # create dataset
 
 
52
  test_dl = learn.dls.test_dl(images, bs=64)
53
  # get predictions
54
  probabilities, _, activations = learn.get_preds(dl=test_dl, with_decoded=True)
 
55
  # swivel into tags list from activations
56
  tags = {}
57
- for x, activation in enumerate(activations):
58
- for idx, i in enumerate(activation):
59
- if i:
60
- tag = learn.dls.vocab[idx]
61
- tag = tag.replace("_", " ")
62
- if tag not in tags:
63
- tags[tag] = {'prob': 0, 'offset': ()}
64
- prob = float(probabilities[x][idx])
65
- if tags[tag]['prob'] < prob:
66
- tags[tag]['prob'] = prob
67
- tags[tag]['offset'] = offsets[x]
 
 
 
 
68
 
69
  return tags
70
 
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  def getVTToffsets(vtt):
 
73
  left = top = right = bottom = None
74
  for line in vtt.decode("utf-8").split("\n"):
75
  line = line.strip()
76
- if "xywh=" in line:
 
 
 
 
 
 
 
 
 
 
 
 
77
  left, top, right, bottom = line.split("xywh=")[-1].split(",")
78
  left, top, right, bottom = (
79
  int(left),
@@ -87,14 +171,30 @@ def getVTToffsets(vtt):
87
  if not left:
88
  continue
89
 
90
- yield left, top, right, bottom
 
 
91
 
 
 
 
 
 
 
 
 
 
92
 
93
- gr.Interface(
94
- fn=predict,
95
  inputs=[
96
  gr.Image(),
97
  gr.Textbox(label="VTT file"),
 
98
  ],
99
  outputs=gr.JSON(label=""),
 
 
 
 
100
  ).launch(enable_queue=True, server_name="0.0.0.0")
 
1
  import base64
2
+ from uuid import uuid4
3
  import gradio as gr
4
 
5
  from fastcore.all import *
6
  from fastai.vision.all import *
7
  import numpy as np
 
8
  import timm
9
 
 
 
10
 
11
  def parent_labels(o):
12
  "Label `item` with the parent folder name."
 
27
  learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)
28
 
29
 
def predict_tags(image, vtt, threshold=0.4):
    """Return, for every tag the model activates, the sprite frame where it scored highest.

    Args:
        image: Sprite sheet in any form accepted by ``PILImage.create``.
        vtt: Base64-encoded VTT text; an optional ``data:text/vtt;base64,``
            prefix is stripped before decoding.
        threshold: Threshold handed to ``BCEWithLogitsLossFlat`` for this
            call. A falsy value (``None`` or ``0``) falls back to ``0.4``.

    Returns:
        A dict keyed by tag label (underscores replaced with spaces). Each
        value holds ``prob``, ``offset`` (the ``(left, top, right, bottom)``
        tuple yielded by ``getVTToffsets``), ``frame`` (index of the frame in
        VTT order) and ``time`` (start time in seconds) of the best frame.
        An empty dict is returned when the VTT yields no frames.
    """
    vtt = base64.b64decode(vtt.replace("data:text/vtt;base64,", ""))
    sprite = PILImage.create(image)

    offsets = []
    times = []
    images = []
    for left, top, right, bottom, time_seconds in getVTToffsets(vtt):
        times.append(time_seconds)
        offsets.append((left, top, right, bottom))
        # The values come from the VTT ``xywh=`` fragment, so the third and
        # fourth numbers are added to the origin to form the crop box.
        cut_frame = sprite.crop((left, top, left + right, top + bottom))
        images.append(PILImage.create(np.asarray(cut_frame)))

    if not images:
        # Nothing to classify; skip the model call entirely.
        return {}

    # The learner is shared module state: apply the per-request threshold only
    # for the duration of the prediction and always restore the default, even
    # when prediction raises.
    threshold = threshold or 0.4
    learn.loss_func = BCEWithLogitsLossFlat(thresh=threshold)
    try:
        test_dl = learn.dls.test_dl(images, bs=64)
        probabilities, _, activations = learn.get_preds(dl=test_dl, with_decoded=True)
    finally:
        learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)

    # Swivel the per-frame activations into a per-tag "best frame" mapping.
    tags = {}
    for frame_idx, activation in enumerate(activations):
        for vocab_idx, active in enumerate(activation):
            if not active:
                continue

            tag = learn.dls.vocab[vocab_idx].replace("_", " ")
            best = tags.setdefault(tag, {'prob': 0, 'offset': (), 'frame': 0})
            prob = float(probabilities[frame_idx][vocab_idx])
            if best['prob'] < prob:
                best['prob'] = prob
                best['offset'] = offsets[frame_idx]
                best['frame'] = frame_idx
                best['time'] = times[frame_idx]

    return tags
71
 
72
 
def predict_markers(image, vtt, threshold=0.4):
    """Build a list of marker frames from a sprite sheet and its VTT index.

    Each sprite frame is classified. A frame's tag is kept only if the next
    tagged frame carries the same label, and a frame is emitted as a marker
    only when it shares no label with the previously emitted marker.

    Args:
        image: Sprite sheet in any form accepted by ``PILImage.create``.
        vtt: Base64-encoded VTT text; an optional ``data:text/vtt;base64,``
            prefix is stripped before decoding.
        threshold: Threshold handed to ``BCEWithLogitsLossFlat`` for this
            call. A falsy value (``None`` or ``0``) falls back to ``0.4``.

    Returns:
        A list of dicts with ``offset``, ``frame``, ``time``, ``tag`` (the
        highest-probability ``{'label', 'prob'}`` entry of the frame) and a
        freshly generated ``id`` (UUID4 string). Empty when the VTT yields no
        frames.
    """
    vtt = base64.b64decode(vtt.replace("data:text/vtt;base64,", ""))
    sprite = PILImage.create(image)

    offsets = []
    times = []
    images = []
    for left, top, right, bottom, time_seconds in getVTToffsets(vtt):
        times.append(time_seconds)
        offsets.append((left, top, right, bottom))
        # The values come from the VTT ``xywh=`` fragment, so the third and
        # fourth numbers are added to the origin to form the crop box.
        cut_frame = sprite.crop((left, top, left + right, top + bottom))
        images.append(PILImage.create(np.asarray(cut_frame)))

    if not images:
        # Nothing to classify; skip the model call entirely.
        return []

    # The learner is shared module state: apply the per-request threshold only
    # for the duration of the prediction and always restore the default, even
    # when prediction raises.
    threshold = threshold or 0.4
    learn.loss_func = BCEWithLogitsLossFlat(thresh=threshold)
    try:
        test_dl = learn.dls.test_dl(images, bs=64)
        probabilities, _, activations = learn.get_preds(dl=test_dl, with_decoded=True)
    finally:
        learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)

    # Collect the active tags of every frame; frames without any tag are dropped.
    all_data_per_frame = []
    for frame_idx, activation in enumerate(activations):
        ftags = [
            {
                'label': learn.dls.vocab[vocab_idx].replace("_", " "),
                'prob': float(probabilities[frame_idx][vocab_idx]),
            }
            for vocab_idx, active in enumerate(activation)
            if active
        ]
        if not ftags:
            continue
        all_data_per_frame.append({
            'offset': offsets[frame_idx],
            'frame': frame_idx,
            'time': times[frame_idx],
            'tags': ftags,
        })

    # Keep only tags that also appear on the following tagged frame. The last
    # frame has no successor and is therefore never kept.
    # NOTE(review): "following" means the next frame that had any tag, which
    # is not necessarily the adjacent sprite frame - confirm this is intended.
    filtered = []
    for frame_data, next_frame_data in zip(all_data_per_frame, all_data_per_frame[1:]):
        next_labels = {entry['label'] for entry in next_frame_data['tags']}
        frame_data['tags'] = [
            entry for entry in frame_data['tags'] if entry['label'] in next_labels
        ]
        if frame_data['tags']:
            filtered.append(frame_data)

    # Emit a marker only when the frame's labels are disjoint from those of
    # the previously emitted marker.
    last_tag = set()
    results = []
    for frame_data in filtered:
        labels = {entry['label'] for entry in frame_data['tags']}
        if labels.intersection(last_tag):
            continue

        last_tag = labels
        # max() returns the first of equally probable entries, matching a
        # stable descending sort followed by [0].
        frame_data['tag'] = max(frame_data['tags'], key=lambda entry: entry['prob'])
        del frame_data['tags']

        # add unique id to the frame
        frame_data['id'] = str(uuid4())
        results.append(frame_data)

    return results
141
+
142
+
143
  def getVTToffsets(vtt):
144
+ time_seconds = 0
145
  left = top = right = bottom = None
146
  for line in vtt.decode("utf-8").split("\n"):
147
  line = line.strip()
148
+
149
+ if "-->" in line:
150
+ # grab the start time
151
+ # 00:00:00.000 --> 00:00:41.000
152
+ start = line.split("-->")[0].strip().split(":")
153
+ # convert to seconds
154
+ time_seconds = (
155
+ int(start[0]) * 3600
156
+ + int(start[1]) * 60
157
+ + float(start[2])
158
+ )
159
+ left = top = right = bottom = None
160
+ elif "xywh=" in line:
161
  left, top, right, bottom = line.split("xywh=")[-1].split(",")
162
  left, top, right, bottom = (
163
  int(left),
 
171
  if not left:
172
  continue
173
 
174
+ yield left, top, right, bottom, time_seconds
175
+
# Gradio app with two tabs: "tag" calls predict_tags and "marker" calls
# predict_markers. Both take the sprite image, the base64 VTT text and a
# threshold, and return JSON.

tag = gr.Interface(
    fn=predict_tags,
    inputs=[
        gr.Image(),
        gr.Textbox(label="VTT file"),
        gr.Number(value=0.4, label="Threshold"),
    ],
    outputs=gr.JSON(label=""),
)

marker = gr.Interface(
    fn=predict_markers,
    inputs=[
        gr.Image(),
        gr.Textbox(label="VTT file"),
        gr.Number(value=0.4, label="Threshold"),
    ],
    outputs=gr.JSON(label=""),
)

# `launch(enable_queue=...)` is deprecated in Gradio 3.x; queueing is enabled
# through `.queue()` instead. Bind on all interfaces as before.
gr.TabbedInterface(
    [tag, marker], ["tag", "marker"]
).queue().launch(server_name="0.0.0.0")
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  fastai==2.7.7
2
  numpy==1.24.2
3
  timm==0.6.7
4
- opennsfw2==0.10.2
5
- gradio
 
1
  fastai==2.7.7
2
  numpy==1.24.2
3
  timm==0.6.7
4
+ gradio==3.50.2