Corey Morris
commited on
Commit
·
9549fcc
1
Parent(s):
85667d0
Added test for removal of undesired columns. fixed code error in column removal
Browse files- result_data_processor.py +1 -2
- test_data_processing.py +11 -1
result_data_processor.py
CHANGED
@@ -83,8 +83,7 @@ class ResultDataProcessor:
|
|
83 |
data = data[cols]
|
84 |
|
85 |
# Drop specific columns
|
86 |
-
data.drop(columns=['all', 'truthfulqa:mc|0'])
|
87 |
-
|
88 |
|
89 |
# Add parameter count column using extract_parameters function
|
90 |
data['Parameters'] = data.index.to_series().apply(self._extract_parameters)
|
|
|
83 |
data = data[cols]
|
84 |
|
85 |
# Drop specific columns
|
86 |
+
data = data.drop(columns=['all', 'truthfulqa:mc|0'])
|
|
|
87 |
|
88 |
# Add parameter count column using extract_parameters function
|
89 |
data['Parameters'] = data.index.to_series().apply(self._extract_parameters)
|
test_data_processing.py
CHANGED
@@ -18,12 +18,22 @@ class TestResultDataProcessor(unittest.TestCase):
|
|
18 |
self.assertIn('Parameters', data.columns)
|
19 |
self.assertIn('MMLU_average', data.columns)
|
20 |
# check number of columns
|
21 |
-
self.assertEqual(len(data.columns),
|
22 |
|
23 |
# check that the number of rows is correct
|
24 |
def test_rows(self):
|
25 |
data = self.processor.data
|
26 |
self.assertEqual(len(data), 992)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
if __name__ == '__main__':
|
29 |
unittest.main()
|
|
|
18 |
self.assertIn('Parameters', data.columns)
|
19 |
self.assertIn('MMLU_average', data.columns)
|
20 |
# check number of columns
|
21 |
+
self.assertEqual(len(data.columns), 61)
|
22 |
|
23 |
# check that the number of rows is correct
|
24 |
def test_rows(self):
|
25 |
data = self.processor.data
|
26 |
self.assertEqual(len(data), 992)
|
27 |
+
|
28 |
+
# # check that mc1 column exists
|
29 |
+
# def test_mc1(self):
|
30 |
+
# data = self.processor.data
|
31 |
+
# self.assertIn('mc1', data.columns)
|
32 |
+
|
33 |
+
# test that a column that contains truthfulqa:mc does not exist
|
34 |
+
def test_truthfulqa_mc(self):
|
35 |
+
data = self.processor.data
|
36 |
+
self.assertNotIn('truthfulqa:mc', data.columns)
|
37 |
|
38 |
if __name__ == '__main__':
|
39 |
unittest.main()
|