Sadjad Alikhani commited on
Commit
6920787
·
verified ·
1 Parent(s): e833bfc

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +1 -6
inference.py CHANGED
@@ -33,7 +33,7 @@ if torch.cuda.is_available():
33
  # Folders
34
  # MODELS_FOLDER = 'models/'
35
 
36
- def dataset_gen(preprocessed_chs, input_type, scenario_idxs, lwm_model):
37
 
38
  if input_type in ['cls_emb', 'channel_emb']:
39
  dataset = prepare_for_LWM(preprocessed_chs, device)
@@ -41,11 +41,6 @@ def dataset_gen(preprocessed_chs, input_type, scenario_idxs, lwm_model):
41
  dataset = create_raw_dataset(preprocessed_chs, device)
42
 
43
  if input_type in ['cls_emb','channel_emb']:
44
- # model = LWM().to(device)
45
- # ckpt_name = 'model_weights.pth'
46
- # ckpt_path = os.path.join(MODELS_FOLDER, ckpt_name)
47
- # lwm_model = load_model(model, ckpt_path, device)
48
- # print(f"Model loaded successfully on {device}")
49
 
50
  # Process data through LWM
51
  lwm_loss, embedding_data = evaluate(lwm_model, dataset)
 
33
  # Folders
34
  # MODELS_FOLDER = 'models/'
35
 
36
+ def dataset_gen(preprocessed_chs, input_type, lwm_model):
37
 
38
  if input_type in ['cls_emb', 'channel_emb']:
39
  dataset = prepare_for_LWM(preprocessed_chs, device)
 
41
  dataset = create_raw_dataset(preprocessed_chs, device)
42
 
43
  if input_type in ['cls_emb','channel_emb']:
 
 
 
 
 
44
 
45
  # Process data through LWM
46
  lwm_loss, embedding_data = evaluate(lwm_model, dataset)