from typing import Dict, List, Tuple

import pandas as pd
from transformers import TapasForQuestionAnswering, TapasTokenizer

from src.constants import id2aggregation


def infer(
    query: str,
    file_name: str,
    model_name: str = "google/tapas-base-finetuned-wtq",
) -> Tuple[Dict[str, str], pd.DataFrame]:
    """Answer a natural-language question about a CSV table with TAPAS.

    Args:
        query: The question to ask about the table.
        file_name: Path to a comma-separated CSV file holding the table.
        model_name: Hugging Face identifier of a TAPAS QA checkpoint.

    Returns:
        A tuple of:
        - a dict with keys ``"query"`` and ``"answer"``; the answer string
          is prefixed with the predicted aggregation operator (e.g.
          ``"SUM : 42"``) unless the operator is ``"NONE"``;
        - the loaded table as a DataFrame of strings.
    """
    # TAPAS expects every cell as text, so cast the whole table to str.
    table = pd.read_csv(file_name, delimiter=",")
    table = table.astype(str)

    # NOTE(review): the model and tokenizer are re-downloaded/re-loaded on
    # every call; cache them (e.g. at module level) if this runs repeatedly.
    model = TapasForQuestionAnswering.from_pretrained(model_name)
    tokenizer = TapasTokenizer.from_pretrained(model_name)

    # Tokenize the (table, question) pair and run the model.
    queries = [query]
    inputs = tokenizer(
        table=table, queries=queries, padding="max_length", return_tensors="pt"
    )
    outputs = model(**inputs)

    # predicted_answer_coordinates: list (per query) of (row, col) answer
    # cells; predicted_aggregation_indices: aggregation-operator id per query.
    predicted_answer_coordinates, predicted_aggregation_indices = (
        tokenizer.convert_logits_to_predictions(
            inputs,
            outputs.logits.detach(),
            outputs.logits_aggregation.detach(),
        )
    )
    aggregation_predictions_string = [
        id2aggregation[idx] for idx in predicted_aggregation_indices
    ]

    # Resolve each query's cell coordinates into answer text; multi-cell
    # answers are joined with ", " (a single cell joins trivially).
    answers: List[str] = [
        ", ".join(table.iat[coordinate] for coordinate in coordinates)
        for coordinates in predicted_answer_coordinates
    ]

    # Only one query is submitted, so the loop produces a single answer
    # string; the aggregation operator is prepended unless it is "NONE".
    answer_str = ""
    for query, answer, predicted_agg in zip(
        queries, answers, aggregation_predictions_string
    ):
        if predicted_agg == "NONE":
            answer_str = answer
        else:
            answer_str = f"{predicted_agg} : {answer}"

    return {"query": query, "answer": answer_str}, table