|
import json |
|
from collections import Counter |
|
import contractions |
|
import csv |
|
import pandas as pd |
|
from typing import Tuple, List, Optional |
|
from my_model.config import dataset_config as config |
|
|
|
class OKVQADatasetProcessor:
    """
    Processes the OKVQA dataset by loading, processing, and merging question and annotation data.

    Attributes:
        questions_file_path (str): Path to the questions JSON file.
        annotations_file_path (str): Path to the annotations JSON file.
        questions (List[dict]): Entries read from the top-level 'questions' key of the questions file.
        annotations (List[dict]): Entries read from the top-level 'annotations' key of the annotations file.
        df_questions (DataFrame): DataFrame holding the questions.
        df_answers (DataFrame): DataFrame holding the annotations.
        merged_df (Optional[DataFrame]): Result of merging questions and answers; None until
            ``merge_data`` has been called.
    """

    def __init__(self, questions_file_path: str, annotations_file_path: str) -> None:
        """
        Initializes the dataset processor with file paths and loads the data into DataFrames.

        Parameters:
            questions_file_path (str): The file path for the questions JSON file.
            annotations_file_path (str): The file path for the annotations JSON file.

        Raises:
            OSError: If either file cannot be opened.
            KeyError: If a file lacks its expected top-level key ('questions' / 'annotations').
        """
        self.questions_file_path = questions_file_path
        self.annotations_file_path = annotations_file_path
        self.questions, self.annotations = self.load_data_files()
        self.df_questions = pd.DataFrame(self.questions)
        self.df_answers = pd.DataFrame(self.annotations)
        self.merged_df: Optional[pd.DataFrame] = None

    def load_data_files(self) -> Tuple[List[dict], List[dict]]:
        """
        Loads the question and annotation data from JSON files.

        Returns:
            Tuple[List[dict], List[dict]]: The values of the 'questions' key of the questions
            file and of the 'annotations' key of the annotations file.

        Raises:
            KeyError: If the expected top-level key is missing from a file.
        """
        # Explicit UTF-8 so decoding does not depend on the platform's locale.
        with open(self.questions_file_path, 'r', encoding='utf-8') as file:
            questions = json.load(file)['questions']

        with open(self.annotations_file_path, 'r', encoding='utf-8') as file:
            annotations = json.load(file)['annotations']

        return questions, annotations

    @staticmethod
    def find_most_frequent(my_list: List[str]) -> Optional[str]:
        """
        Determines the most frequent item in a list.

        Ties are resolved in favour of the item encountered first (Counter.most_common ordering).

        Parameters:
            my_list (List[str]): The list from which to find the most frequent item.

        Returns:
            Optional[str]: The most frequent item or None if the list is empty.
        """
        if not my_list:
            return None
        return Counter(my_list).most_common(1)[0][0]

    def merge_data(self) -> None:
        """
        Merges the question and answer DataFrames on their shared keys.

        Sets ``merged_df`` to the result of ``pd.merge`` of ``df_questions`` and ``df_answers``
        on both 'question_id' and 'image_id' (pandas' default inner join), so both columns
        must be present in both DataFrames.
        """
        self.merged_df = pd.merge(self.df_questions, self.df_answers, on=['question_id', 'image_id'])

    def join_words_with_hyphen(self, sentence: Optional[str]) -> Optional[str]:
        """
        Collapses a multi-word answer into a single hyphen-joined token.

        Whitespace runs are treated as single separators, e.g. "ice  cream" -> "ice-cream".

        Parameters:
            sentence (Optional[str]): The answer text.

        Returns:
            Optional[str]: The hyphen-joined text, or None when ``sentence`` is not a string
            (e.g. the None produced by ``find_most_frequent`` for an empty answer list).
        """
        if not isinstance(sentence, str):
            return None
        return '-'.join(sentence.split())

    def process_answers(self) -> None:
        """
        Processes answers from the merged DataFrame by extracting and identifying the most frequent answers.

        Adds the columns 'raw_answers', 'processed_answers', 'most_frequent_raw_answer',
        'most_frequent_processed_answer' and 'single_word_answers', and drops 'answers'.
        If ``merge_data`` has not been called, prints a message and does nothing.
        """
        if self.merged_df is None:
            print("DataFrames have not been merged yet.")
            # Return early: every step below needs a merged DataFrame.
            return

        self.merged_df['raw_answers'] = self.merged_df['answers'].apply(
            lambda answers: [ans['raw_answer'] for ans in answers])
        self.merged_df['processed_answers'] = self.merged_df['answers'].apply(
            lambda answers: [ans['answer'] for ans in answers])
        self.merged_df['most_frequent_raw_answer'] = self.merged_df['raw_answers'].apply(
            self.find_most_frequent)
        self.merged_df['most_frequent_processed_answer'] = self.merged_df['processed_answers'].apply(
            self.find_most_frequent)
        self.merged_df.drop(columns=['answers'], inplace=True)
        self.merged_df['single_word_answers'] = self.merged_df['most_frequent_processed_answer'].apply(
            self.join_words_with_hyphen)

    def get_processed_data(self) -> Optional[pd.DataFrame]:
        """
        Retrieves the processed DataFrame.

        Returns:
            Optional[pd.DataFrame]: The processed DataFrame or None if it is not available.
        """
        if self.merged_df is not None:
            return self.merged_df
        print("DataFrame is empty or not processed yet.")
        return None

    def save_to_csv(self, df: pd.DataFrame, saved_file_name: Optional[str] = None) -> None:
        """
        Saves the DataFrame to a CSV file without the index column.

        Parameters:
            df (pd.DataFrame): The DataFrame to save.
            saved_file_name (Optional[str]): The target file name or path. A '.csv' suffix is
                appended when missing (case-insensitive check); None writes to 'data.csv'.
        """
        if saved_file_name is None:
            target = "data.csv"
        elif saved_file_name.lower().endswith(".csv"):
            target = saved_file_name
        else:
            # Append the extension; os.path.join would instead create "<name>/.csv".
            target = saved_file_name + ".csv"
        df.to_csv(target, index=False)

    def display_dataframe(self) -> None:
        """
        Displays the processed DataFrame.
        """
        if self.merged_df is not None:
            print(self.merged_df)
        else:
            print("DataFrame is empty.")
|
|
|
|
|
|
|
def process_okvqa_dataset(questions_file_path: str, annotations_file_path: str, save_to_csv: bool = False,
                          saved_file_name: Optional[str] = None) -> Optional[pd.DataFrame]:
    """
    Run the full OK-VQA pipeline: load, merge, process answers and optionally export.

    Parameters:
        questions_file_path (str): Path to the questions JSON file.
        annotations_file_path (str): Path to the annotations JSON file.
        save_to_csv (bool): When true and processing produced a DataFrame, write it to CSV.
        saved_file_name (Optional[str]): Filename or path for the CSV file; None is passed
            through to the processor's ``save_to_csv``.

    Returns:
        Optional[pd.DataFrame]: The merged and processed DataFrame, or None if the processor
        reports no processed data.
    """
    okvqa = OKVQADatasetProcessor(questions_file_path, annotations_file_path)

    # Merge first: answer processing operates on the merged table.
    okvqa.merge_data()
    okvqa.process_answers()

    result = okvqa.get_processed_data()

    should_export = save_to_csv and result is not None
    if should_export:
        okvqa.save_to_csv(result, saved_file_name)

    return result
|
|