Vrk commited on
Commit
ddcd2c4
1 Parent(s): 6f8bfb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -0
app.py CHANGED
@@ -34,6 +34,19 @@ def get_model(model_name, classes, device):
34
  model.load_state_dict(torch.load('BaseLine-Model.pt', map_location=torch.device(device)))
35
 
36
  return model
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def make_predictions(input_img, model_name):
39
  classes = ['buildings','forest', 'glacier', 'mountain', 'sea', 'street']
34
  model.load_state_dict(torch.load('BaseLine-Model.pt', map_location=torch.device(device)))
35
 
36
  return model
37
+
38
+ def get_transform(input_img, device):
39
+ normalize = transforms.Normalize(
40
+ [0.485, 0.456, 0.406],
41
+ [0.229, 0.224, 0.225]
42
+ )
43
+
44
+ test_transform = transforms.Compose([
45
+ transforms.ToTensor(),
46
+ normalize,
47
+ ])
48
+ input_img = test_transform(input_img).unsqueeze(0).to(device)
49
+ return input_img
50
 
51
  def make_predictions(input_img, model_name):
52
  classes = ['buildings','forest', 'glacier', 'mountain', 'sea', 'street']