# KB-VQA-E — my_model/dataset/dataset_processor.py
import json
from collections import Counter
import contractions
import csv
import pandas as pd
from typing import Tuple, List, Optional
from my_model.config import dataset_config as config
class OKVQADatasetProcessor:
    """
    Processes the OKVQA dataset by loading, processing, and merging question and annotation data.

    Attributes:
        questions_file_path (str): Path to the questions JSON file.
        annotations_file_path (str): Path to the annotations JSON file.
        questions (List[dict]): Extracted list of question entries from the JSON file.
        annotations (List[dict]): Extracted list of annotation entries from the JSON file.
        df_questions (pd.DataFrame): DataFrame holding the questions.
        df_answers (pd.DataFrame): DataFrame holding the annotations.
        merged_df (Optional[pd.DataFrame]): DataFrame resulting from merging questions and
            answers; None until merge_data() has been called.
    """

    def __init__(self, questions_file_path: str, annotations_file_path: str) -> None:
        """
        Initializes the dataset processor with file paths and loads the data into DataFrames.

        Parameters:
            questions_file_path (str): The file path for the questions JSON file.
            annotations_file_path (str): The file path for the annotations JSON file.
        """
        self.questions_file_path = questions_file_path
        self.annotations_file_path = annotations_file_path
        self.questions, self.annotations = self.load_data_files()
        self.df_questions = pd.DataFrame(self.questions)
        self.df_answers = pd.DataFrame(self.annotations)
        self.merged_df: Optional[pd.DataFrame] = None

    def load_data_files(self) -> Tuple[List[dict], List[dict]]:
        """
        Loads the question and annotation data from JSON files.

        Returns:
            Tuple[List[dict], List[dict]]: A tuple containing lists of questions and annotations.
        """
        with open(self.questions_file_path, 'r') as file:
            questions = json.load(file)['questions']
        with open(self.annotations_file_path, 'r') as file:
            annotations = json.load(file)['annotations']
        return questions, annotations

    @staticmethod
    def find_most_frequent(my_list: List[str]) -> Optional[str]:
        """
        Determines the most frequent item in a list.

        Parameters:
            my_list (List[str]): The list from which to find the most frequent item.

        Returns:
            Optional[str]: The most frequent item or None if the list is empty.
        """
        if not my_list:
            return None
        # Counter.most_common(1) returns [(item, count)] for the top item.
        return Counter(my_list).most_common(1)[0][0]

    def merge_data(self) -> None:
        """
        Merges the question and answer DataFrames on their common keys.

        Sets the 'merged_df' attribute to the result of an inner merge of
        'df_questions' and 'df_answers' on both 'question_id' and 'image_id',
        which are assumed to be present in both DataFrames.
        """
        self.merged_df = pd.merge(self.df_questions, self.df_answers, on=['question_id', 'image_id'])

    def join_words_with_hyphen(self, sentence: str) -> str:
        """
        Joins the whitespace-separated words of a sentence with hyphens.

        Parameters:
            sentence (str): The sentence to convert (e.g. 'big cat').

        Returns:
            str: The hyphen-joined form (e.g. 'big-cat').
        """
        return '-'.join(sentence.split())

    def process_answers(self) -> None:
        """
        Processes answers from the merged DataFrame.

        Extracts the raw and processed answer lists from each row's 'answers'
        entries, derives the most frequent of each, and builds a hyphen-joined
        single-token answer column. Prints a message and does nothing if the
        DataFrames have not been merged yet.
        """
        if self.merged_df is None:
            print("DataFrames have not been merged yet.")
            return
        self.merged_df['raw_answers'] = self.merged_df['answers'].apply(
            lambda x: [ans['raw_answer'] for ans in x])
        self.merged_df['processed_answers'] = self.merged_df['answers'].apply(
            lambda x: [ans['answer'] for ans in x])
        self.merged_df['most_frequent_raw_answer'] = self.merged_df['raw_answers'].apply(self.find_most_frequent)
        self.merged_df['most_frequent_processed_answer'] = self.merged_df['processed_answers'].apply(
            self.find_most_frequent)
        self.merged_df.drop(columns=['answers'], inplace=True)
        # BUG FIX: this assignment previously sat outside the None-guard and ran
        # even when merged_df was None, raising a TypeError. It now runs only
        # after a successful merge.
        self.merged_df['single_word_answers'] = self.merged_df['most_frequent_processed_answer'].apply(
            self.join_words_with_hyphen)

    def get_processed_data(self) -> Optional[pd.DataFrame]:
        """
        Retrieves the processed DataFrame.

        Returns:
            Optional[pd.DataFrame]: The processed DataFrame or None if it is not available.
        """
        if self.merged_df is None:
            print("DataFrame is empty or not processed yet.")
            return None
        return self.merged_df

    def save_to_csv(self, df: pd.DataFrame, saved_file_name: Optional[str]) -> None:
        """
        Saves the DataFrame to a CSV file (without the index column).

        Parameters:
            df (pd.DataFrame): The DataFrame to save.
            saved_file_name (Optional[str]): The target file name or path. A '.csv'
                extension is appended if missing; defaults to 'data.csv' when None.
        """
        if saved_file_name is None:
            saved_file_name = "data.csv"
        elif not saved_file_name.endswith(".csv"):
            # BUG FIX: the original used os.path.join(saved_file_name, ".csv"),
            # which produced 'name/.csv' (and 'os' was never imported).
            saved_file_name = saved_file_name + ".csv"
        df.to_csv(saved_file_name, index=False)

    def display_dataframe(self) -> None:
        """
        Displays the processed DataFrame, or a message if it is not available.
        """
        if self.merged_df is not None:
            print(self.merged_df)
        else:
            print("DataFrame is empty.")
def process_okvqa_dataset(questions_file_path: str, annotations_file_path: str, save_to_csv: bool = False,
                          saved_file_name: Optional[str] = None) -> Optional[pd.DataFrame]:
    """
    Runs the full OK-VQA processing pipeline over the given question and annotation files.

    Parameters:
        questions_file_path (str): Path to the questions JSON file.
        annotations_file_path (str): Path to the annotations JSON file.
        save_to_csv (bool): Flag to determine if the processed data should be saved to CSV.
        saved_file_name (Optional[str]): Filename or path to save the CSV file. If None, defaults to 'data.csv'.

    Returns:
        Optional[pd.DataFrame]: The processed DataFrame containing merged and processed VQA data or None if empty.
    """
    # Build the processor, then merge questions with annotations and derive answer columns.
    dataset_processor = OKVQADatasetProcessor(questions_file_path, annotations_file_path)
    dataset_processor.merge_data()
    dataset_processor.process_answers()

    result = dataset_processor.get_processed_data()

    # Persist to disk only when requested and there is something to write.
    if save_to_csv and result is not None:
        dataset_processor.save_to_csv(result, saved_file_name)
    return result