fohy24 commited on
Commit
3802471
·
1 Parent(s): e6445f0

get IMAGE_SIZE from checkpoint

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -47,6 +47,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
  hf_token = os.getenv('HF_token')
48
  model_path = hf_hub_download(repo_id="samfhy/morphmarket_model", filename="model_v13_1_epoch9.pt", token=hf_token)
49
  checkpoint = torch.load(model_path, map_location=device)
 
50
 
51
  new_layers = nn.Sequential(
52
  nn.LazyLinear(2048),
@@ -56,7 +57,7 @@ new_layers = nn.Sequential(
56
  nn.LazyLinear(num_labels)
57
  )
58
 
59
- IMAGE_SIZE = 480
60
  transform = v2.Compose([
61
  v2.ToImage(),
62
  v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
 
47
  hf_token = os.getenv('HF_token')
48
  model_path = hf_hub_download(repo_id="samfhy/morphmarket_model", filename="model_v13_1_epoch9.pt", token=hf_token)
49
  checkpoint = torch.load(model_path, map_location=device)
50
+ print(checkpoint.keys())
51
 
52
  new_layers = nn.Sequential(
53
  nn.LazyLinear(2048),
 
57
  nn.LazyLinear(num_labels)
58
  )
59
 
60
+ IMAGE_SIZE = checkpoint['image_size']
61
  transform = v2.Compose([
62
  v2.ToImage(),
63
  v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),