Sadjad Alikhani
commited on
Update inference.py
Browse files- inference.py +1 -6
inference.py
CHANGED
@@ -33,7 +33,7 @@ if torch.cuda.is_available():
|
|
33 |
# Folders
|
34 |
# MODELS_FOLDER = 'models/'
|
35 |
|
36 |
-
def dataset_gen(preprocessed_chs, input_type,
|
37 |
|
38 |
if input_type in ['cls_emb', 'channel_emb']:
|
39 |
dataset = prepare_for_LWM(preprocessed_chs, device)
|
@@ -41,11 +41,6 @@ def dataset_gen(preprocessed_chs, input_type, scenario_idxs, lwm_model):
|
|
41 |
dataset = create_raw_dataset(preprocessed_chs, device)
|
42 |
|
43 |
if input_type in ['cls_emb','channel_emb']:
|
44 |
-
# model = LWM().to(device)
|
45 |
-
# ckpt_name = 'model_weights.pth'
|
46 |
-
# ckpt_path = os.path.join(MODELS_FOLDER, ckpt_name)
|
47 |
-
# lwm_model = load_model(model, ckpt_path, device)
|
48 |
-
# print(f"Model loaded successfully on {device}")
|
49 |
|
50 |
# Process data through LWM
|
51 |
lwm_loss, embedding_data = evaluate(lwm_model, dataset)
|
|
|
33 |
# Folders
|
34 |
# MODELS_FOLDER = 'models/'
|
35 |
|
36 |
+
def dataset_gen(preprocessed_chs, input_type, lwm_model):
|
37 |
|
38 |
if input_type in ['cls_emb', 'channel_emb']:
|
39 |
dataset = prepare_for_LWM(preprocessed_chs, device)
|
|
|
41 |
dataset = create_raw_dataset(preprocessed_chs, device)
|
42 |
|
43 |
if input_type in ['cls_emb','channel_emb']:
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
# Process data through LWM
|
46 |
lwm_loss, embedding_data = evaluate(lwm_model, dataset)
|