Meena commited on
Commit
1ce01d1
1 Parent(s): 6e9a986

Update app/tapas.py

Browse files
Files changed (1) hide show
  1. app/tapas.py +48 -0
app/tapas.py CHANGED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import TapasTokenizer, TFTapasForQuestionAnswering
2
+ import pandas as pd
3
+
4
+ def execute_query(query, csv_file):
5
+
6
+ table = pd.read_csv(csv_file.name, delimiter=",")
7
+ table.fillna(0, inplace=True)
8
+ table = table.astype(str)
9
+
10
+ model_name = "google/tapas-base-finetuned-wtq"
11
+ model = TFTapasForQuestionAnswering.from_pretrained(model_name)
12
+ tokenizer = TapasTokenizer.from_pretrained(model_name)
13
+
14
+ queries = [query]
15
+
16
+ inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="tf")
17
+ outputs = model(**inputs)
18
+
19
+ predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
20
+ inputs, outputs.logits, outputs.logits_aggregation
21
+ )
22
+
23
+ # let's print out the results:
24
+ id2aggregation = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}
25
+ aggregation_predictions_string = [id2aggregation[x] for x in predicted_aggregation_indices]
26
+
27
+ answers = []
28
+ for coordinates in predicted_answer_coordinates:
29
+ if len(coordinates) == 1:
30
+ # only a single cell:
31
+ answers.append(table.iat[coordinates[0]])
32
+ else:
33
+ # multiple cells
34
+ cell_values = []
35
+ for coordinate in coordinates:
36
+ cell_values.append(table.iat[coordinate])
37
+ answers.append(cell_values)
38
+
39
+ for query, answer, predicted_agg in zip(queries, answers, aggregation_predictions_string):
40
+ if predicted_agg != "NONE":
41
+ answers.append(predicted_agg)
42
+
43
+ query_result = {
44
+ "query": query,
45
+ "result": answers
46
+ }
47
+
48
+ return query_result, table