Spaces:
Configuration error
Configuration error
englert
committed on
Commit
·
1f88a07
1
Parent(s):
a8b501d
update app.py and resnet50.py
Browse files
- app.py +6 -4
- resnet50.py +1 -1
app.py
CHANGED
@@ -34,8 +34,8 @@ def predict(input_file, downsample_size):
|
|
34 |
|
35 |
zip_path = os.path.join(input_file.split('/')[-1][:-4] + ".zip")
|
36 |
|
37 |
-
mean = np.asarray([0.3156024, 0.33569682, 0.34337464])
|
38 |
-
std = np.asarray([0.16568947, 0.17827448, 0.18925823])
|
39 |
|
40 |
img_vecs = []
|
41 |
with torch.no_grad():
|
@@ -46,8 +46,9 @@ def predict(input_file, downsample_size):
|
|
46 |
to_rgb=True)):
|
47 |
in_img = (in_img.astype(np.float32) / 255.)
|
48 |
in_img = (in_img - mean) / std
|
|
|
49 |
in_img = np.transpose(in_img, (0, 3, 1, 2))
|
50 |
-
in_img = torch.from_numpy(in_img)
|
51 |
encoded = avg_pool(model(in_img))[0, :, 0, 0].cpu().numpy()
|
52 |
img_vecs += [encoded]
|
53 |
|
@@ -82,7 +83,8 @@ def predict(input_file, downsample_size):
|
|
82 |
|
83 |
demo = gr.Interface(
|
84 |
fn=predict,
|
85 |
-
inputs=[gr.inputs.Video(label="Upload Video File"),
|
|
|
86 |
outputs=gr.outputs.File(label="Zip"))
|
87 |
|
88 |
demo.launch()
|
|
|
34 |
|
35 |
zip_path = os.path.join(input_file.split('/')[-1][:-4] + ".zip")
|
36 |
|
37 |
+
mean = np.asarray([0.3156024, 0.33569682, 0.34337464], dtype=np.float32)
|
38 |
+
std = np.asarray([0.16568947, 0.17827448, 0.18925823], dtype=np.float32)
|
39 |
|
40 |
img_vecs = []
|
41 |
with torch.no_grad():
|
|
|
46 |
to_rgb=True)):
|
47 |
in_img = (in_img.astype(np.float32) / 255.)
|
48 |
in_img = (in_img - mean) / std
|
49 |
+
in_img = np.expand_dims(in_img, 0)
|
50 |
in_img = np.transpose(in_img, (0, 3, 1, 2))
|
51 |
+
in_img = torch.from_numpy(in_img).float()
|
52 |
encoded = avg_pool(model(in_img))[0, :, 0, 0].cpu().numpy()
|
53 |
img_vecs += [encoded]
|
54 |
|
|
|
83 |
|
84 |
demo = gr.Interface(
|
85 |
fn=predict,
|
86 |
+
inputs=[gr.inputs.Video(label="Upload Video File"),
|
87 |
+
gr.inputs.Number(label="Downsample size")],
|
88 |
outputs=gr.outputs.File(label="Zip"))
|
89 |
|
90 |
demo.launch()
|
resnet50.py
CHANGED
@@ -314,7 +314,7 @@ class ResNet(nn.Module):
|
|
314 |
)[1], 0)
|
315 |
start_idx = 0
|
316 |
for end_idx in idx_crops:
|
317 |
-
_out = self.forward_backbone(torch.cat(inputs[start_idx: end_idx]).cuda(non_blocking=True)
|
318 |
if start_idx == 0:
|
319 |
output = _out
|
320 |
else:
|
|
|
314 |
)[1], 0)
|
315 |
start_idx = 0
|
316 |
for end_idx in idx_crops:
|
317 |
+
_out = self.forward_backbone(torch.cat(inputs[start_idx: end_idx])) # .cuda(non_blocking=True)
|
318 |
if start_idx == 0:
|
319 |
output = _out
|
320 |
else:
|