import gradio as gr
import pandas as pd
from transformers import TapexTokenizer, BartForConditionalGeneration, pipeline

# Maximum encoder input length (query + flattened table) used for TAPEX; also
# the per-chunk budget used when splitting large tables.
TAPEX_MAX_TOKENS = 1024

# Query keywords for which numeric answers from several chunks are summed.
_SUM_QUERY_KEYWORDS = ("sum", "total", "how many", "count", "number of")

# Initialize TAPEX (Microsoft) model and tokenizer
tokenizer_tapex = TapexTokenizer.from_pretrained("microsoft/tapex-large-finetuned-wtq")
model_tapex = BartForConditionalGeneration.from_pretrained("microsoft/tapex-large-finetuned-wtq")

# Initialize TAPAS (Google) models and pipelines
pipe_tapas = pipeline(task="table-question-answering", model="google/tapas-large-finetuned-wtq")
pipe_tapas2 = pipeline(
    task="table-question-answering", model="google/tapas-large-finetuned-wikisql-supervised"
)


def chunk_dataframe(df, max_tokens=TAPEX_MAX_TOKENS):
    """Split *df* into consecutive row slices that roughly fit *max_tokens*.

    Heuristic: every cell is assumed to cost about one token, so a chunk holds
    ``max_tokens // number_of_columns`` rows. Real rows usually tokenize longer,
    so the TAPEX tokenizer's truncation remains the final guard.

    Args:
        df: Table to split.
        max_tokens: Approximate token budget per chunk.

    Returns:
        A list of row slices of *df*; empty when *df* has no rows.
    """
    # Guard against ZeroDivisionError on a frame with no columns.
    n_columns = max(len(df.columns), 1)
    # At least one row per chunk: a step of 0 would make range() raise ValueError
    # when the table has more columns than max_tokens.
    chunk_size = max(max_tokens // n_columns, 1)
    return [df.iloc[start:start + chunk_size] for start in range(0, len(df), chunk_size)]


def _aggregate_chunk_answers(query, answers):
    """Combine the per-chunk TAPEX answers for *query* into one answer string.

    - No non-empty answers -> ``""``.
    - Exactly one non-empty answer -> that answer, unchanged apart from stripping.
    - Several answers, a sum/count-style query, and every answer numeric ->
      their sum (integral totals are printed without a trailing ``.0``).
    - Otherwise -> the answers joined with single spaces, so non-numeric
      answers are never discarded.
    """
    cleaned = [answer.strip() for answer in answers if answer.strip()]
    if not cleaned:
        return ""
    if len(cleaned) == 1:
        return cleaned[0]

    lowered_query = query.lower()
    if any(keyword in lowered_query for keyword in _SUM_QUERY_KEYWORDS):
        try:
            numbers = [float(answer) for answer in cleaned]
        except ValueError:
            pass  # At least one answer is not a number: fall through to joining.
        else:
            total = sum(numbers)
            return str(int(total)) if total.is_integer() else str(total)

    return " ".join(cleaned)


def process_table_query(query, table_data):
    """Answer *query* over *table_data* with TAPEX, chunking large tables.

    Args:
        query: Natural-language question.
        table_data: pandas DataFrame holding the table.

    Returns:
        The TAPEX answer string. When the table spans several chunks, the
        per-chunk answers are combined by ``_aggregate_chunk_answers``.
    """
    # TapexTokenizer expects every cell to be a string.
    table_data = table_data.astype(str)

    answers = []
    for chunk in chunk_dataframe(table_data):
        encoding = tokenizer_tapex(
            table=chunk,
            query=query,
            return_tensors="pt",
            max_length=TAPEX_MAX_TOKENS,
            truncation=True,
        )
        outputs = model_tapex.generate(**encoding)
        answers.append(tokenizer_tapex.batch_decode(outputs, skip_special_tokens=True)[0])

    return _aggregate_chunk_answers(query, answers)


def _normalize_table(table_data):
    """Lowercase text cells, expand datetime columns, and stringify every cell."""
    # Lowercase only real strings in object columns; other values (e.g. NaN) pass through.
    for column in table_data.columns:
        if table_data[column].dtype == "object":
            table_data[column] = table_data[column].apply(
                lambda x: x.lower() if isinstance(x, str) else x
            )

    # NOTE(review): pd.read_csv is called without parse_dates, so columns are
    # rarely datetime64 here; this expansion only fires when they are.
    datetime_columns = [
        column
        for column in table_data.columns
        if pd.api.types.is_datetime64_any_dtype(table_data[column])
    ]
    for column in datetime_columns:
        accessor = table_data[column].dt
        table_data[f"{column}_year"] = accessor.year
        table_data[f"{column}_month"] = accessor.month
        table_data[f"{column}_day"] = accessor.day
        table_data[f"{column}_time"] = accessor.strftime("%H:%M:%S")

    # Both the TAPEX and TAPAS tokenizers require string cells.
    return table_data.astype(str)


def _tapas_answer(tapas_pipeline, table_data, query):
    """Run a TAPAS table-QA pipeline and return its ``answer`` string.

    ``answer`` joins every selected cell and, for models with an aggregation
    head, prefixes the operator (e.g. ``"SUM > 3, 4"``). Indexing ``cells[0]``
    instead would drop the extra cells and the operator, and would raise
    IndexError when no cell is selected. Rows that do not fit the model's input
    are dropped rather than raising.
    """
    result = tapas_pipeline(table=table_data, query=query, truncation="drop_rows_to_fit")
    return result["answer"]


def answer_query_from_csv(query, file):
    """Read an uploaded CSV and answer *query* with TAPEX and both TAPAS models.

    Args:
        query: Natural-language question from the UI.
        file: Value of the ``gr.File`` component. This is a filepath string, or,
            in older Gradio versions, a temp-file object exposing ``.name``.

    Returns:
        Tuple ``(tapex_answer, tapas_wtq_answer, tapas_wikisql_answer)``.

    Raises:
        gr.Error: If no file or query was provided or the CSV cannot be parsed;
            Gradio shows the message to the user.
    """
    if file is None:
        raise gr.Error("Please upload a CSV file.")
    if not query or not query.strip():
        raise gr.Error("Please enter a query.")

    path = getattr(file, "name", file)
    try:
        table_data = pd.read_csv(path)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
        raise gr.Error(f"Could not read the CSV file: {err}") from err

    table_data = _normalize_table(table_data)

    # Process the CSV data and query using TAPEX
    result_tapex = process_table_query(query, table_data)

    # Process the query using the TAPAS pipelines
    result_tapas = _tapas_answer(pipe_tapas, table_data, query)
    result_tapas2 = _tapas_answer(pipe_tapas2, table_data, query)

    return result_tapex, result_tapas, result_tapas2


# Create Gradio interface
with gr.Blocks() as interface:
    gr.Markdown("# Table Question Answering with TAPEX and TAPAS Models")

    # Notice about the model input limits
    gr.Markdown(
        f"### Note: TAPEX considers only the first {TAPEX_MAX_TOKENS} tokens "
        "(query + table data) per chunk. If your table is too large, it will be "
        "chunked and processed separately. The TAPAS models instead drop rows "
        "that do not fit their input."
    )

    # Two-column layout (input on the left, output on the right)
    with gr.Row():
        with gr.Column():
            # Input fields for the query and file
            query_input = gr.Textbox(label="Enter your query:")
            csv_input = gr.File(label="Upload your CSV file")

        with gr.Column():
            # Output textboxes for the answers
            result_tapex = gr.Textbox(label="TAPEX Answer")
            result_tapas = gr.Textbox(label="TAPAS (WikiTableQuestions) Answer")
            result_tapas2 = gr.Textbox(label="TAPAS (WikiSQL) Answer")

    # Submit button
    submit_btn = gr.Button("Submit")

    # Action when submit button is clicked
    submit_btn.click(
        fn=answer_query_from_csv,
        inputs=[query_input, csv_input],
        outputs=[result_tapex, result_tapas, result_tapas2],
    )


if __name__ == "__main__":
    # Launch the Gradio interface
    interface.launch(share=True)