Spaces:
Running
Running
Sadjad Alikhani
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -139,7 +139,7 @@ def identical_train_test_split(output_emb, output_raw, labels, percentage_idx):
|
|
139 |
indices = torch.randperm(N) # Randomly shuffle the indices
|
140 |
|
141 |
# Calculate the split index
|
142 |
-
split_index = int(N * percentage_values[percentage_idx])
|
143 |
print(f'Training Size: {split_index}')
|
144 |
|
145 |
# Split indices into train and test
|
@@ -211,13 +211,11 @@ def process_hdf5_file(uploaded_file, percentage_idx):
|
|
211 |
|
212 |
# Step 7: Tokenize the data using the tokenizer from input_preprocess
|
213 |
preprocessed_chs = input_preprocess.tokenizer(manual_data=channels)
|
214 |
-
print(preprocessed_chs[0][0][1])
|
215 |
|
216 |
# Step 7: Perform inference using the functions from inference.py
|
217 |
output_emb = inference.lwm_inference(preprocessed_chs, 'channel_emb', model)
|
218 |
-
#print(f'output_emb:{output_emb[10][0]}')
|
219 |
output_raw = inference.create_raw_dataset(preprocessed_chs, device)
|
220 |
-
#print(f'output_raw:{output_raw[10][0]}')
|
221 |
|
222 |
print(f"Output Embeddings Shape: {output_emb.shape}")
|
223 |
print(f"Output Raw Shape: {output_raw.shape}")
|
|
|
139 |
indices = torch.randperm(N) # Randomly shuffle the indices
|
140 |
|
141 |
# Calculate the split index
|
142 |
+
split_index = int(N * percentage_values[percentage_idx-1]/10)
|
143 |
print(f'Training Size: {split_index}')
|
144 |
|
145 |
# Split indices into train and test
|
|
|
211 |
|
212 |
# Step 7: Tokenize the data using the tokenizer from input_preprocess
|
213 |
preprocessed_chs = input_preprocess.tokenizer(manual_data=channels)
|
214 |
+
#print(preprocessed_chs[0][0][1])
|
215 |
|
216 |
# Step 7: Perform inference using the functions from inference.py
|
217 |
output_emb = inference.lwm_inference(preprocessed_chs, 'channel_emb', model)
|
|
|
218 |
output_raw = inference.create_raw_dataset(preprocessed_chs, device)
|
|
|
219 |
|
220 |
print(f"Output Embeddings Shape: {output_emb.shape}")
|
221 |
print(f"Output Raw Shape: {output_raw.shape}")
|