chychiu commited on
Commit
4598f99
1 Parent(s): 4d6ee04

final submission yolo

Browse files
Files changed (1) hide show
  1. script.py +35 -35
script.py CHANGED
@@ -313,41 +313,41 @@ def make_submission(metadata_df):
313
  OUTPUT_CSV_PATH = "./submission.csv"
314
  BASE_CKPT_PATH = "./checkpoints"
315
 
316
- # model_names = [
317
- # "dino_2_optuna_05242231.ckpt",
318
- # "dino_optuna_05241449.ckpt",
319
- # "dino_optuna_05241257.ckpt",
320
- # "dino_optuna_05241222.ckpt",
321
- # "dino_2_optuna_05242055.ckpt",
322
- # "dino_2_optuna_05242156.ckpt",
323
- # "dino_2_optuna_05242344.ckpt",
324
- # ]
325
-
326
- # models = []
327
-
328
- # for model_path in model_names:
329
- # print("loading ", model_path)
330
- # ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)
331
-
332
- # ckpt = torch.load(ckpt_path)
333
- # model = FungiMEEModel()
334
- # model.load_state_dict(
335
- # {w: ckpt["model." + w] for w in model.state_dict().keys()}
336
- # )
337
- # model.eval()
338
- # model.cuda()
339
-
340
- # models.append(model)
341
-
342
- # fungi_model = FungiEnsembleModel(models)
343
-
344
- ckpt_path = os.path.join(BASE_CKPT_PATH, "dino_2_optuna_05242055.ckpt")
345
-
346
- fungi_model = FungiMEEModel()
347
- ckpt = torch.load(ckpt_path)
348
- fungi_model.load_state_dict(
349
- {w: ckpt["model." + w] for w in fungi_model.state_dict().keys()}
350
- )
351
 
352
  embedding_dataset = EmbeddingMetadataDataset(metadata_df)
353
  loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)
 
313
  OUTPUT_CSV_PATH = "./submission.csv"
314
  BASE_CKPT_PATH = "./checkpoints"
315
 
316
+ model_names = [
317
+ "dino_2_optuna_05242231.ckpt",
318
+ "dino_optuna_05241449.ckpt",
319
+ "dino_optuna_05241257.ckpt",
320
+ "dino_optuna_05241222.ckpt",
321
+ "dino_2_optuna_05242055.ckpt",
322
+ "dino_2_optuna_05242156.ckpt",
323
+ "dino_2_optuna_05242344.ckpt",
324
+ ]
325
+
326
+ models = []
327
+
328
+ for model_path in model_names:
329
+ print("loading ", model_path)
330
+ ckpt_path = os.path.join(BASE_CKPT_PATH, model_path)
331
+
332
+ ckpt = torch.load(ckpt_path)
333
+ model = FungiMEEModel()
334
+ model.load_state_dict(
335
+ {w: ckpt["model." + w] for w in model.state_dict().keys()}
336
+ )
337
+ model.eval()
338
+ model.cuda()
339
+
340
+ models.append(model)
341
+
342
+ fungi_model = FungiEnsembleModel(models)
343
+
344
+ # ckpt_path = os.path.join(BASE_CKPT_PATH, "dino_2_optuna_05242055.ckpt")
345
+
346
+ # fungi_model = FungiMEEModel()
347
+ # ckpt = torch.load(ckpt_path)
348
+ # fungi_model.load_state_dict(
349
+ # {w: ckpt["model." + w] for w in fungi_model.state_dict().keys()}
350
+ # )
351
 
352
  embedding_dataset = EmbeddingMetadataDataset(metadata_df)
353
  loader = DataLoader(embedding_dataset, batch_size=128, shuffle=False)