Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -93,8 +93,8 @@ size = 256
|
|
93 |
means = [0.485, 0.456, 0.406]
|
94 |
stds = [0.229, 0.224, 0.225]
|
95 |
|
96 |
-
t_stds = torch.tensor(stds).cpu()[:,None,None]
|
97 |
-
t_means = torch.tensor(means).cpu()[:,None,None]
|
98 |
|
99 |
def makeEven(_x):
|
100 |
return int(_x) if (_x % 2 == 0) else int(_x+1)
|
@@ -107,7 +107,7 @@ def tensor2im(var):
|
|
107 |
return var.mul(t_stds).add(t_means).mul(255.).clamp(0,255).permute(1,2,0)
|
108 |
|
109 |
def proc_pil_img(input_image, model):
|
110 |
-
transformed_image = img_transforms(input_image)[None,...].cpu()()
|
111 |
|
112 |
with torch.no_grad():
|
113 |
result_image = model(transformed_image)[0]
|
@@ -118,9 +118,9 @@ def proc_pil_img(input_image, model):
|
|
118 |
|
119 |
|
120 |
|
121 |
-
modelv4 = torch.jit.load(modelarcanev4).eval().cpu()
|
122 |
-
modelv3 = torch.jit.load(modelarcanev3).eval().cpu()
|
123 |
-
modelv2 = torch.jit.load(modelarcanev2).eval().cpu()
|
124 |
|
125 |
def process(im, version):
|
126 |
if version == 'version 0.4':
|
|
|
93 |
means = [0.485, 0.456, 0.406]
|
94 |
stds = [0.229, 0.224, 0.225]
|
95 |
|
96 |
+
t_stds = torch.tensor(stds).cpu().half().float()[:,None,None]
|
97 |
+
t_means = torch.tensor(means).cpu().half().float()[:,None,None]
|
98 |
|
99 |
def makeEven(_x):
|
100 |
return int(_x) if (_x % 2 == 0) else int(_x+1)
|
|
|
107 |
return var.mul(t_stds).add(t_means).mul(255.).clamp(0,255).permute(1,2,0)
|
108 |
|
109 |
def proc_pil_img(input_image, model):
|
110 |
+
transformed_image = img_transforms(input_image)[None,...].cpu().half().float()
|
111 |
|
112 |
with torch.no_grad():
|
113 |
result_image = model(transformed_image)[0]
|
|
|
118 |
|
119 |
|
120 |
|
121 |
+
modelv4 = torch.jit.load(modelarcanev4,map_location='cpu').eval().cpu().half().float()
|
122 |
+
modelv3 = torch.jit.load(modelarcanev3,map_location='cpu').eval().cpu().half().float()
|
123 |
+
modelv2 = torch.jit.load(modelarcanev2,map_location='cpu').eval().cpu().half().float()
|
124 |
|
125 |
def process(im, version):
|
126 |
if version == 'version 0.4':
|