File size: 7,038 Bytes
484e080
 
 
 
55cd839
484e080
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
020595f
484e080
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
020595f
484e080
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
020595f
484e080
 
 
 
 
 
 
 
 
 
 
 
020595f
484e080
 
020595f
484e080
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
020595f
484e080
 
 
 
 
 
 
 
 
 
 
 
 
 
020595f
484e080
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
020595f
484e080
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import json
from collections import Counter
import contractions
import csv
import pandas as pd
from typing import Tuple, List, Optional
from my_model.config import dataset_config as config

class OKVQADatasetProcessor:
    """
    Loads the OK-VQA question and annotation JSON files, merges them into a single
    DataFrame and derives per-question answer columns.

    Attributes:
        questions_file_path (str): Path to the questions JSON file.
        annotations_file_path (str): Path to the annotations JSON file.
        questions (List[dict]): Entries read from the 'questions' key of the questions file.
        annotations (List[dict]): Entries read from the 'annotations' key of the annotations file.
        df_questions (DataFrame): DataFrame built from `questions`.
        df_answers (DataFrame): DataFrame built from `annotations`.
        merged_df (Optional[DataFrame]): Result of `merge_data()`; None until that method is called.
    """

    def __init__(self, questions_file_path: str, annotations_file_path: str) -> None:
        """
        Stores the file paths and immediately loads both JSON files into DataFrames.

        Parameters:
            questions_file_path (str): The file path for the questions JSON file.
            annotations_file_path (str): The file path for the annotations JSON file.
        """
        self.questions_file_path = questions_file_path
        self.annotations_file_path = annotations_file_path
        self.questions, self.annotations = self.load_data_files()
        self.df_questions = pd.DataFrame(self.questions)
        self.df_answers = pd.DataFrame(self.annotations)
        self.merged_df = None

    def load_data_files(self) -> Tuple[List[dict], List[dict]]:
        """
        Loads the question and annotation data from the configured JSON files.

        Returns:
            Tuple[List[dict], List[dict]]: The value of the 'questions' key of the questions file
            and the value of the 'annotations' key of the annotations file.
        """
        # JSON text is UTF-8; being explicit avoids depending on the platform locale.
        with open(self.questions_file_path, 'r', encoding='utf-8') as file:
            questions = json.load(file)['questions']

        with open(self.annotations_file_path, 'r', encoding='utf-8') as file:
            annotations = json.load(file)['annotations']

        return questions, annotations

    @staticmethod
    def find_most_frequent(my_list: List[str]) -> Optional[str]:
        """
        Determines the most frequent item in a list.

        Parameters:
            my_list (List[str]): The list from which to find the most frequent item.

        Returns:
            Optional[str]: The most frequent item, or None if the list is empty.
        """
        if not my_list:
            return None
        return Counter(my_list).most_common(1)[0][0]

    def merge_data(self) -> None:
        """
        Merges the question and answer DataFrames.

        Sets 'merged_df' to the result of merging 'df_questions' and 'df_answers' on both the
        'question_id' and 'image_id' columns, which must be present in both DataFrames.
        """
        self.merged_df = pd.merge(self.df_questions, self.df_answers, on=['question_id', 'image_id'])

    def join_words_with_hyphen(self, sentence: Optional[str]) -> Optional[str]:
        """
        Collapses a whitespace-separated phrase into a single hyphen-joined token.

        Parameters:
            sentence (Optional[str]): The phrase to convert, e.g. 'hot dog'.

        Returns:
            Optional[str]: The hyphen-joined phrase ('hot-dog'), or None when the input is not
            a string (for example a question whose answer list was empty).
        """
        # find_most_frequent yields None for an empty answer list; pass that through
        # instead of failing on None.split().
        if not isinstance(sentence, str):
            return None
        return '-'.join(sentence.split())

    def process_answers(self) -> None:
        """
        Derives answer columns on the merged DataFrame.

        Adds 'raw_answers' and 'processed_answers' (lists taken from the 'raw_answer' and
        'answer' fields of each entry in 'answers'), their most frequent values, and
        'single_word_answers' (the most frequent processed answer joined with hyphens),
        then drops the original 'answers' column. If `merge_data()` has not been called,
        a message is printed and nothing else happens.
        """
        if self.merged_df is None:
            print("DataFrames have not been merged yet.")
            # Nothing to process; returning here avoids indexing into None below.
            return

        merged = self.merged_df
        merged['raw_answers'] = merged['answers'].apply(
            lambda entries: [ans['raw_answer'] for ans in entries])
        merged['processed_answers'] = merged['answers'].apply(
            lambda entries: [ans['answer'] for ans in entries])
        merged['most_frequent_raw_answer'] = merged['raw_answers'].apply(self.find_most_frequent)
        merged['most_frequent_processed_answer'] = merged['processed_answers'].apply(
            self.find_most_frequent)
        merged.drop(columns=['answers'], inplace=True)

        merged['single_word_answers'] = merged['most_frequent_processed_answer'].apply(
            self.join_words_with_hyphen)

    def get_processed_data(self) -> Optional[pd.DataFrame]:
        """
        Retrieves the merged DataFrame.

        Returns:
            Optional[pd.DataFrame]: 'merged_df', or None (after printing a message) if
            `merge_data()` has not been called yet.
        """
        if self.merged_df is None:
            print("DataFrame is empty or not processed yet.")
            return None
        return self.merged_df

    def save_to_csv(self, df: pd.DataFrame, saved_file_name: Optional[str]) -> None:
        """
        Saves the DataFrame to a CSV file without the index column.

        Parameters:
            df (pd.DataFrame): The DataFrame to save.
            saved_file_name (Optional[str]): The target file name or path. A '.csv' suffix is
                appended when missing. If None, the file is written to 'data.csv'.
        """
        if saved_file_name is None:
            target = "data.csv"
        elif saved_file_name.lower().endswith(".csv"):
            target = saved_file_name
        else:
            # Append the extension to the name itself; joining it as a path component
            # would produce '<name>/.csv'.
            target = saved_file_name + ".csv"
        df.to_csv(target, index=False)

    def display_dataframe(self) -> None:
        """
        Prints the merged DataFrame, or a notice if it has not been created yet.
        """
        if self.merged_df is None:
            print("DataFrame is empty.")
        else:
            print(self.merged_df)



def process_okvqa_dataset(questions_file_path: str, annotations_file_path: str, save_to_csv: bool = False,
                          saved_file_name: Optional[str] = None) -> Optional[pd.DataFrame]:
    """
    Runs the full OK-VQA pipeline: load both JSON files, merge them, derive the answer
    columns and optionally write the result to CSV.

    Parameters:
        questions_file_path (str): Path to the questions JSON file.
        annotations_file_path (str): Path to the annotations JSON file.
        save_to_csv (bool): When true, the processed data is also written to a CSV file.
        saved_file_name (Optional[str]): Filename or path for the CSV file; passed through to
            the processor's `save_to_csv`, which falls back to 'data.csv' when it is None.

    Returns:
        Optional[pd.DataFrame]: The merged and processed DataFrame, or None if the processor
        has no data to return.
    """
    pipeline = OKVQADatasetProcessor(questions_file_path, annotations_file_path)

    # Merging must precede answer processing, which operates on the merged frame.
    pipeline.merge_data()
    pipeline.process_answers()

    result = pipeline.get_processed_data()

    # Nothing to persist when saving was not requested or no data is available.
    if not save_to_csv or result is None:
        return result

    pipeline.save_to_csv(result, saved_file_name)
    return result