m7mdal7aj commited on
Commit
484e080
1 Parent(s): 7854237

Create dataset_processor.py

Browse files
Files changed (1) hide show
  1. my_model/dataset/dataset_processor.py +176 -0
my_model/dataset/dataset_processor.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import csv
import json
from collections import Counter
from typing import List, Optional, Tuple

import contractions
import pandas as pd

from my_model.config import dataset_config as config
7
+
8
class OKVQADatasetProcessor:
    """
    Processes the OKVQA dataset by loading, merging, and post-processing
    question and annotation data.

    Attributes:
        questions_file_path (str): Path to the questions JSON file.
        annotations_file_path (str): Path to the annotations JSON file.
        questions (List[dict]): Question entries loaded from the JSON file.
        annotations (List[dict]): Annotation entries loaded from the JSON file.
        df_questions (pd.DataFrame): DataFrame holding the questions.
        df_answers (pd.DataFrame): DataFrame holding the annotations.
        merged_df (Optional[pd.DataFrame]): Result of merging questions and
            answers; None until merge_data() has been called.
    """

    def __init__(self, questions_file_path: str, annotations_file_path: str) -> None:
        """
        Initializes the dataset processor with file paths and loads the data into DataFrames.

        Parameters:
            questions_file_path (str): The file path for the questions JSON file.
            annotations_file_path (str): The file path for the annotations JSON file.
        """
        self.questions_file_path = questions_file_path
        self.annotations_file_path = annotations_file_path
        self.questions, self.annotations = self.load_data_files()
        self.df_questions = pd.DataFrame(self.questions)
        self.df_answers = pd.DataFrame(self.annotations)
        self.merged_df: Optional[pd.DataFrame] = None

    def load_data_files(self) -> Tuple[List[dict], List[dict]]:
        """
        Loads the question and annotation data from JSON files.

        Returns:
            Tuple[List[dict], List[dict]]: A tuple containing lists of questions and annotations.
        """
        with open(self.questions_file_path, 'r') as file:
            questions = json.load(file)['questions']

        with open(self.annotations_file_path, 'r') as file:
            annotations = json.load(file)['annotations']

        return questions, annotations

    @staticmethod
    def find_most_frequent(my_list: List[str]) -> Optional[str]:
        """
        Determines the most frequent item in a list.

        Parameters:
            my_list (List[str]): The list from which to find the most frequent item.

        Returns:
            Optional[str]: The most frequent item or None if the list is empty.
        """
        if not my_list:
            return None
        return Counter(my_list).most_common(1)[0][0]

    def merge_data(self) -> None:
        """
        Merges the question and answer DataFrames on their common keys.

        Sets the 'merged_df' attribute to the inner join of 'df_questions'
        and 'df_answers' on the 'question_id' and 'image_id' fields, which
        are assumed to be present in both DataFrames.
        """
        self.merged_df = pd.merge(self.df_questions, self.df_answers,
                                  on=['question_id', 'image_id'])

    def join_words_with_hyphen(self, sentence: str) -> str:
        """
        Joins the whitespace-separated words of a sentence with hyphens,
        e.g. 'tabby cat' -> 'tabby-cat'.

        Parameters:
            sentence (str): The sentence to collapse into a single token.

        Returns:
            str: The hyphen-joined sentence.
        """
        return '-'.join(sentence.split())

    def process_answers(self) -> None:
        """
        Processes answers from the merged DataFrame by extracting raw and
        processed answer lists, identifying the most frequent answer of each,
        and collapsing multi-word answers into single hyphenated tokens.

        Prints a warning and returns without changes if merge_data() has not
        been called yet.
        """
        if self.merged_df is None:
            # Bug fix: the single-word step previously ran even when merged_df
            # was None (it sat outside this guard), raising an AttributeError
            # right after the warning. Bail out early instead.
            print("DataFrames have not been merged yet.")
            return

        self.merged_df['raw_answers'] = self.merged_df['answers'].apply(
            lambda x: [ans['raw_answer'] for ans in x])
        self.merged_df['processed_answers'] = self.merged_df['answers'].apply(
            lambda x: [ans['answer'] for ans in x])
        self.merged_df['most_frequent_raw_answer'] = self.merged_df['raw_answers'].apply(
            self.find_most_frequent)
        self.merged_df['most_frequent_processed_answer'] = self.merged_df[
            'processed_answers'].apply(self.find_most_frequent)
        self.merged_df.drop(columns=['answers'], inplace=True)

        # Hyphenate multi-word answers so each answer becomes a single token.
        self.merged_df['single_word_answers'] = self.merged_df[
            'most_frequent_processed_answer'].apply(self.join_words_with_hyphen)

    def get_processed_data(self) -> Optional[pd.DataFrame]:
        """
        Retrieves the processed DataFrame.

        Returns:
            Optional[pd.DataFrame]: The processed DataFrame or None if it is not available.
        """
        if self.merged_df is None:
            print("DataFrame is empty or not processed yet.")
            return None
        return self.merged_df

    def save_to_csv(self, df: pd.DataFrame, saved_file_name: Optional[str]) -> None:
        """
        Saves the DataFrame to a CSV file without the index column.

        Parameters:
            df (pd.DataFrame): The DataFrame to save.
            saved_file_name (Optional[str]): The target file name or path.
                A '.csv' extension is appended if missing; defaults to
                'data.csv' when None.
        """
        if saved_file_name is None:
            df.to_csv("data.csv", index=False)
        elif not saved_file_name.endswith(".csv"):
            # Bug fix: os.path.join(name, ".csv") produced 'name/.csv'
            # (a hidden '.csv' file inside a directory) and relied on an
            # un-imported 'os' module; append the extension instead.
            df.to_csv(saved_file_name + ".csv", index=False)
        else:
            df.to_csv(saved_file_name, index=False)

    def display_dataframe(self) -> None:
        """
        Displays the processed DataFrame, or a notice if it is empty.
        """
        if self.merged_df is not None:
            print(self.merged_df)
        else:
            print("DataFrame is empty.")
145
+
146
+
147
+
148
def process_okvqa_dataset(questions_file_path: str, annotations_file_path: str, save_to_csv: bool = False,
                          saved_file_name: Optional[str] = None) -> Optional[pd.DataFrame]:
    """
    Orchestrates the processing of the OK-VQA dataset using specified JSON file paths for questions and annotations.

    Parameters:
        questions_file_path (str): Path to the questions JSON file.
        annotations_file_path (str): Path to the annotations JSON file.
        save_to_csv (bool): Flag to determine if the processed data should be saved to CSV.
        saved_file_name (Optional[str]): Filename or path to save the CSV file. If None, defaults to 'data.csv'.

    Returns:
        Optional[pd.DataFrame]: The processed DataFrame containing merged and processed VQA data or None if empty.
    """
    # Build the processor, then run the merge + answer-processing pipeline.
    okvqa_processor = OKVQADatasetProcessor(questions_file_path, annotations_file_path)
    okvqa_processor.merge_data()
    okvqa_processor.process_answers()

    # Fetch the result; may be None if processing produced nothing.
    dataset = okvqa_processor.get_processed_data()

    # Persist to CSV only when requested and there is data to write.
    if save_to_csv and dataset is not None:
        okvqa_processor.save_to_csv(dataset, saved_file_name)

    return dataset