geninhu commited on
Commit
0cb5c0e
1 Parent(s): ee501ae

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -21,7 +21,7 @@ def _get_layers(arch:str, pretrained=True):
21
  "Get the layers and arch for a VGG Model (16 and 19 are supported only)"
22
  feat_net = vgg19(pretrained=pretrained) if arch.find('9') > 1 else vgg16(pretrained=pretrained)
23
  config = _vgg_config.get(arch)
24
- features = feat_net.features.cuda().eval()
25
  for p in features.parameters(): p.requires_grad=False
26
  return feat_net, [features[i] for i in config]
27
 
@@ -38,7 +38,7 @@ learner = from_pretrained_fastai(repo_id)
38
 
39
  def infer(img):
40
  pred = learner.predict(img)
41
- image = pred[0].cpu().numpy()
42
  image = image.transpose((1, 2, 0))
43
  plt.imshow(image)
44
  return plt.gcf() #pred[0].show()
 
21
  "Get the layers and arch for a VGG Model (16 and 19 are supported only)"
22
  feat_net = vgg19(pretrained=pretrained) if arch.find('9') > 1 else vgg16(pretrained=pretrained)
23
  config = _vgg_config.get(arch)
24
+ features = feat_net.features.eval()
25
  for p in features.parameters(): p.requires_grad=False
26
  return feat_net, [features[i] for i in config]
27
 
 
38
 
39
  def infer(img):
40
  pred = learner.predict(img)
41
+ image = pred[0].numpy()
42
  image = image.transpose((1, 2, 0))
43
  plt.imshow(image)
44
  return plt.gcf() #pred[0].show()