Upload 2 files
- app.py +65 -27
- new_test_saved_finetuned_model.py +62 -5
app.py
CHANGED
@@ -29,25 +29,37 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
     # shutil.copyfile(label.name, saved_test_label)
     # shutil.copyfile(info.name, saved_train_info)
     parent_location="ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/"
-    test_info_location=parent_location+"test_info.txt"
-    test_location=parent_location+"test.txt"
-    label_location=parent_location+"test_label.txt"
-    if(model_name=="ASTRA-FT-HGR"):
+    test_info_location=parent_location+"overallTestData/test_info.txt"
+    test_location=parent_location+"overallTestData/test.txt"
+    label_location=parent_location+"overallTestData/test_label.txt"
+    # "ASTRA-FT-HGR-RANDOM10", "ASTRA-FT-FIRST10-WSKILLS", "ASTRA-FT-FIRST10-WTIME", "ASTRA-FT-FIRST10-WSKILLS_WTIME"
+    checkpoint = "ratio_proportion_change3_2223/sch_largest_100-coded/output/"
+    if(model_name=="ASTRA-FT-HGR-RANDOM10"):
         finetune_task="highGRschool10"
-
-        # test_location=parent_location+"fullTest/test.txt"
+        checkpoint += "highGRschool10/bert_fine_tuned.model.ep42"
     elif(model_name== "ASTRA-FT-LGR" ):
         finetune_task="lowGRschoolAll"
-        # test_info_location=parent_location+"lowGRschoolAll/test_info.txt"
-        # test_location=parent_location+"lowGRschoolAll/test.txt"
     elif(model_name=="ASTRA-FT-FULL"):
-        # test_info_location=parent_location+"fullTest/test_info.txt"
-        # test_location=parent_location+"fullTest/test.txt"
         finetune_task="fullTest"
+    elif(model_name in ["ASTRA-FT-FIRST10-WSKILLS", "ASTRA-FT-FIRST10-WTIME", "ASTRA-FT-FIRST10-WSKILLS_WTIME"]):
+        finetune_task="first10"
+        if model_name == "ASTRA-FT-FIRST10-WSKILLS":
+            checkpoint += "first10/bert_fine_tuned.model.first10%.wskills.ep24"
+        elif model_name == "ASTRA-FT-FIRST10-WTIME":
+            checkpoint += "first10/bert_fine_tuned.model.first10%.wfaopttime.wttime.wttopttime.wttnoopttime.ep23"
+        elif model_name == "ASTRA-FT-FIRST10-WSKILLS_WTIME":
+            checkpoint += "first10/bert_fine_tuned.model.first10%.wskills.wfaopttime.wttime.wttopttime.wttnoopttime.ep40"
     else:
         finetune_task=None
     # Load the test_info file and the graduation rate file
     test_info = pd.read_csv(test_info_location, sep=',', header=None, engine='python')
+    def convert_etalon(x):
+        means_and_extremes = 1
+        if x.is_integer():
+            means_and_extremes = 0
+        return means_and_extremes
+
+    test_info[8] = test_info[7].apply(convert_etalon) # 7th column contains etalon of factor which decides the ER/ME problem type
     grad_rate_data = pd.DataFrame(pd.read_pickle('assests/school_grduation_rate.pkl'),columns=['school_number','grad_rate']) # Load the grad_rate data

     # Step 1: Extract unique school numbers from test_info
@@ -57,7 +69,7 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
     schools = grad_rate_data[grad_rate_data['school_number'].isin(unique_schools)]

     # Define a threshold for high and low graduation rates (adjust as needed)
-    grad_rate_threshold = 0.9
+    grad_rate_threshold = 0.8 #0.9

     # Step 4: Divide schools into high and low graduation rate groups
     high_grad_schools = schools[schools['grad_rate'] >= grad_rate_threshold]['school_number'].unique()
@@ -113,17 +125,19 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
         'high' if idx in high_indices else 'low' for idx in selected_rows_df2.index
     ]
     # Group data by opt_task1 and opt_task2 based on test_info[6]
-    opt_task_groups = ['opt_task1' if test_info.loc[idx, 6] == 0 else 'opt_task2' for idx in selected_rows_df2.index]
+
+    opt_task_groups = ['opt_task1' if test_info.loc[idx, 8] == 0 else 'opt_task2' for idx in selected_rows_df2.index]
     progress(0.2, desc="Running fine-tuned models...")
     print("finetuned task: ",finetune_task)
     subprocess.run([
         "python", "new_test_saved_finetuned_model.py",
         "-workspace_name", "ratio_proportion_change3_2223/sch_largest_100-coded",
+        "-model_name", model_name,
         "-finetune_task", finetune_task,
         "-test_dataset_path","../../../../fileHandler/selected_rows.txt",
         # "-test_label_path","../../../../train_label.txt",
-        "-finetuned_bert_classifier_checkpoint",
-        "
+        "-finetuned_bert_classifier_checkpoint", checkpoint,
+        "-s",str(128),
         "-e",str(1),
         "-b",str(1000)
     ])
@@ -132,6 +146,8 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
     # Load tlb and plb
     with open("fileHandler/tlabels_plabels.pkl", "rb") as f:
         tlb, plb = pickle.load(f)
+    print("t==p = 0: ", sum([t==p for t,p in zip(tlb, plb) if t==0]))
+    print("t==p = 1: ", sum([t==p for t,p in zip(tlb, plb) if t==1]))

     # Define function to filter and write CSV
     def process_and_write_csv(filtered_data, filename):
@@ -152,20 +168,40 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):

         row_num = 1
         for _, row in filtered_data.iterrows():
-            school, class_id, student_id, status, problem, _, time_zone, duration, attempts = row[:9]
-            steps_data = row[9]
+            # school, class_id, student_id, status, problem, _, time_zone, duration, attempts = row[:9]
+
+            # sch_NPHBD11809,17,stu_CRJBA61379,GRADUATED,ratio_proportion_change3-134,[strategygame],1,4.0,4.0,10,
+            # PercentChange-Attempt-1-0-OK-1667479255281 NumeratorQuantity1-Attempt-1-0-JIT-1667479268893 NumeratorQuantity1-Attempt-2-0-ERROR-1667479284199 NumeratorQuantity1-Attempt-3-0-OK-1667479294890 DenominatorQuantity1-Attempt-1-0-OK-1667479298749 NumeratorQuantity2-Attempt-1-0-OK-1667479301999 OptionalTask_1-Attempt-1-0-OK-1667479304886 DenominatorFactor-Attempt-1-0-OK-1667479314566 NumeratorFactor-Attempt-1-0-OK-1667479315579 EquationAnswer-Attempt-1-0-OK-1667479323750 FinalAnswerDirection-Attempt-1-0-OK-1667479333439 FinalAnswer-Attempt-1-0-OK-1667479338185,
+            # 1,
+            # 0 0.999767840033168 0 0 0.999996274310286 0 0.321529253998353 0.999722748307354 0.999840947031115,
+            # 0 -0.0002057730279919623 0 0 -3.302306839980673e-06 0 -0.41429892410820995 -0.00022392554103201068 -0.00012846367037400164,
+            # 0 0.999767840033168 0 0 0 0 0 0 0,
+            # 1667479255281 1667479294890 1667479298749 1667479301999 1667479304886 1667479314566 1667479315579 1667479323750 1667479333439 1667479338185,
+            # 0 39609 3859 3250 2887 9680 1013 8171 9689 4746,
+            # 2887 9680 1013 8171,0 39609 3859 3250 9689 4746,
+            # 14435,
+            # 82904
+            school, prob_solved, student_id, status, problem, prob_type, opt_type, _, _ = row[:9]
+            steps_data = row[10]
+
+            # if row_num == 1:
+            #     print(row)

             for step in steps_data.split('\t'):
                 step_parts = step.split('-')
+
                 step_name = step_parts[0]
-                action = step_parts[1]
-                attempt = step_parts[2]
-
+                action = step_parts[1]
+                attempt = step_parts[2]
+                help_level = step_parts[3]
+                outcome = step_parts[4]
+                curr_time = step_parts[5]

                 row_data = [
-                    row_num, "", "", student_id, "",
-
-
+                    row_num, "", "", student_id, "", curr_time, "", "", "", "", "",
+                    "ratio_proportion_change3", problem, "", "", step_name, attempt, "", outcome, "", action, "",
+                    "", "", help_level, "", "", "", "", "", "", "",
+                    school, "", "", "", "", "", status, ""
                 ]
                 writer.writerow(row_data)
                 row_num += 1
@@ -179,7 +215,8 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):

         # Filter the data
         filtered_data = selected_test_info.iloc[matching_indices]
-        filtered_data = filtered_data[filtered_data[6] == task_type] # Ensure test_info[6] matches
+        # new data contains etalon instead of 0/1 for ER/ME
+        filtered_data = filtered_data[filtered_data[8] == task_type] # Ensure test_info[6] matches

         # Define filename dynamically
         task_type_map = {0: "ER", 1: "ME"}
@@ -291,8 +328,7 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
         data = file.readlines()
     selected_data = [data[i] for i in indices if i < len(data)]
     # Assuming test_info[7] is a list with ideal tasks for each instance
-    ideal_tasks = test_info[
-
+    ideal_tasks = test_info[8] # A list where each element is either 1 or 2
     # Initialize counters
     task_counts = {
         1: {"ER": 0, "ME": 0, "both": 0,"none":0},
@@ -665,7 +701,7 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
    Model: {model_name}
    ---------------------------\n
    Time Taken: {result['time_taken_from_start']:.2f} seconds
-    Number of schools sampled: {len(
+    Number of schools sampled: {len(random_schools)}
    Total number of instances from HGR schools : {len(high_indices)}
    Total number of instances from LGR schools: {len(low_indices)}

@@ -741,7 +777,9 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
 # List of models for the dropdown menu

 # models = ["ASTRA-FT-HGR", "ASTRA-FT-LGR", "ASTRA-FT-FULL"]
-models = ["ASTRA-FT-HGR", "ASTRA-FT-FULL"]
+# models = ["ASTRA-FT-HGR", "ASTRA-FT-FULL"]
+models = ["ASTRA-FT-HGR-RANDOM10", "ASTRA-FT-FIRST10-WSKILLS", "ASTRA-FT-FIRST10-WTIME", "ASTRA-FT-FIRST10-WSKILLS_WTIME"]
+
 content = """
 <h1 style="color: black;">A S T R A</h1>
 <h2 style="color: black;">An AI Model for Analyzing Math Strategies</h2>
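Note on the ER/ME split introduced above: the new convert_etalon helper derives column 8 of test_info from the etalon value in column 7, and task_type_map later labels 0 as "ER" and 1 as "ME". A minimal standalone sketch of that mapping; the sample etalon values below are hypothetical, not taken from the dataset:

import pandas as pd

# Same helper as in the commit: an integer-valued etalon maps to 0
# ("ER" in task_type_map), a fractional one to 1 ("ME").
def convert_etalon(x):
    means_and_extremes = 1
    if x.is_integer():
        means_and_extremes = 0
    return means_and_extremes

# Hypothetical test_info with the etalon in unnamed column 7,
# mirroring pd.read_csv(..., header=None) in app.py.
test_info = pd.DataFrame({7: [2.0, 1.5, 3.0, 0.75]})
test_info[8] = test_info[7].apply(convert_etalon)
print(test_info[8].tolist())  # -> [0, 1, 0, 1]

So whole-number etalons land in the ER bucket and fractional ones in ME, which is what the downstream filtered_data[filtered_data[8] == task_type] selection relies on.
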
new_test_saved_finetuned_model.py
CHANGED
@@ -6,7 +6,7 @@ from torch.optim import Adam
 from torch.utils.data import DataLoader
 import pickle
 print("here1",os.getcwd())
-from src.dataset import TokenizerDataset,
+from src.dataset import TokenizerDataset, TokenizerwSkillsDataset, TokenizerwTimeDataset, TokenizerwSkillsTimeDataset
 from src.vocab import Vocab
 print("here3",os.getcwd())
 from src.bert import BERT
@@ -19,6 +19,7 @@ import tqdm
 import sys
 import time
 import numpy as np
+from sklearn.preprocessing import QuantileTransformer

 from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, roc_curve, roc_auc_score
 import matplotlib.pyplot as plt
@@ -466,12 +467,59 @@ class BERTFineTuneCalibratedTrainer:
         sys.stdout = sys.__stdout__
         sys.stdout = sys.__stdout__

-
+def prepare_normalized_time_df():
+    faopt_time = []
+    total_time = []
+    nonopt_time = []
+    opt_time = []
+    school = []
+    student = []
+    progress = []
+    prob_id = []
+
+    with open("ratio_proportion_change3_2223/sch_largest_100-coded/finetuning/fullData/train_info.txt", "r") as f:
+        for line in f:
+            line = line.strip()
+            if line:
+                line = line.split(",")
+                sch = line[0]
+                school.append(sch)
+                stu = line[2]
+                student.append(stu)
+                status = line[3]
+                progress.append(status)
+                pid = line[4]
+                prob_id.append(pid)
+                total = float(line[-1])#/60000
+                faopt = float(line[-2])#/60000
+                nonopt = sum([float(i) for i in line[-3].split("\t")])
+                opt = sum([float(i) for i in line[-4].split("\t")])
+                faopt_time.append(faopt)
+                total_time.append(total)
+                nonopt_time.append(nonopt)
+                opt_time.append(opt)
+
+    df = pd.DataFrame({"school": school, "student": student, "progress": progress, "prob_id": prob_id,
+                       "faopt_time": faopt_time, "total_time": total_time,
+                       "nonopt_time": nonopt_time, "opt_time": opt_time})
+    for col in df.columns:
+        print(col, col.endswith('time'))
+        if col.endswith('time'): #col == "faopt_time" or col =="total_time":
+            num_df = df[col]
+            col_values = num_df.values.reshape(-1, 1)
+            nt = QuantileTransformer(output_distribution='normal')
+            col_values_norm = nt.fit_transform(col_values)
+            df[col] = col_values_norm
+            print(df[col].describe())
+    df.set_index(["school", "student", "progress", "prob_id"], inplace=True)
+    df.to_pickle("ratio_proportion_change3_2223/sch_largest_100-coded/time_info/full_data_normalized_time.pkl")
+

 def train():
     parser = argparse.ArgumentParser()

     parser.add_argument('-workspace_name', type=str, default=None)
+    parser.add_argument('-model_name', type=str, default=None)
     parser.add_argument('-code', type=str, default=None, help="folder for pretraining outputs and logs")
     parser.add_argument('-finetune_task', type=str, default=None, help="folder inside finetuning")
     parser.add_argument("-attention", type=bool, default=False, help="analyse attention scores")
@@ -559,10 +607,19 @@ def train():
     vocab_obj.load_vocab()
     print("Vocab Size: ", len(vocab_obj.vocab))

-
+    prepare_normalized_time_df()
     print("Testing using finetuned model......")
-    print("Loading Test Dataset", args.test_dataset_path)
-    test_dataset = TokenizerDataset(args.test_dataset_path, args.test_label_path, vocab_obj, seq_len=args.seq_len)
+    print("Loading Test Dataset", args.test_dataset_path)
+    # "ASTRA-FT-HGR-RANDOM10", "ASTRA-FT-FIRST10-WSKILLS", "ASTRA-FT-FIRST10-WTIME", "ASTRA-FT-FIRST10-WSKILLS_WTIME"
+    # test_dataset = TokenizerDataset(args.test_dataset_path, args.test_label_path, vocab_obj, seq_len=args.seq_len)
+    if args.model_name == "ASTRA-FT-HGR-RANDOM10":
+        test_dataset = TokenizerwSkillsDataset(args.test_dataset_path, args.test_label_path, vocab_obj, seq_len=args.seq_len)
+    elif args.model_name == "ASTRA-FT-FIRST10-WSKILLS":
+        test_dataset = TokenizerwSkillsDataset(args.test_dataset_path, args.test_label_path, vocab_obj, seq_len=args.seq_len)
+    elif args.model_name == "ASTRA-FT-FIRST10-WTIME":
+        test_dataset = TokenizerwTimeDataset(args.test_dataset_path, args.test_label_path, vocab_obj, seq_len=args.seq_len)
+    elif args.model_name == "ASTRA-FT-FIRST10-WSKILLS_WTIME":
+        test_dataset = TokenizerwSkillsTimeDataset(args.test_dataset_path, args.test_label_path, vocab_obj, seq_len=args.seq_len)
     # test_dataset = TokenizerDatasetForCalibration(args.test_dataset_path, args.test_label_path, vocab_obj, seq_len=args.seq_len)

     print("Creating Dataloader...")
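Note on the time normalization added above: prepare_normalized_time_df uses scikit-learn's QuantileTransformer with output_distribution='normal' to replace each raw *_time column with rank-based normal scores before pickling the frame. A minimal sketch of that transform on synthetic durations; the column name and values here are illustrative only, not from the dataset:

import numpy as np
import pandas as pd
from sklearn.preprocessing import QuantileTransformer

# Heavily right-skewed fake durations (milliseconds), stand-ins for
# the real total_time values read from train_info.txt.
rng = np.random.default_rng(0)
df = pd.DataFrame({"total_time": rng.lognormal(mean=10, sigma=1.0, size=1000)})

# Same pattern as the commit: fit on a 2-D (n, 1) view of the column,
# then write the normal-scored values back; ravel() flattens to 1-D.
nt = QuantileTransformer(output_distribution='normal')
df["total_time"] = nt.fit_transform(df[["total_time"]].values).ravel()

print(df["total_time"].describe())  # roughly zero-centred, unit-ish spread

Because the transform is rank-based, it tames the heavy right tail of raw durations without hand-tuned clipping, which is presumably why it is applied before the *_time features reach the fine-tuned models.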