# coding=utf-8 # Copyright 2022 The Microsoft, The Google and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import dataclasses import enum import functools import math import re # The following script is adapted from the script of TaPas. # Original: https://github.com/google-research/tapas/master/wikisql_utils.py from typing import Any, List, Text EMPTY_ANSWER = "none" EMPTY_ANSWER_AGG = "none" def _split_thousands(delimiter, value): split = value.split(delimiter) return len(split) > 1 and any((len(x) == 3 for x in split)) def convert_to_float(value): """Converts value to a float using a series of increasingly complex heuristics. Args: value: object that needs to be converted. Allowed types include float/int/strings. Returns: A float interpretation of value. Raises: ValueError if the float conversion of value fails. """ if isinstance(value, float): return value if isinstance(value, int): return float(value) if not isinstance(value, str): raise ValueError("Argument value is not a string. Can't parse it as float") sanitized = value try: # Example: 1,000.7 if "." in sanitized and "," in sanitized: return float(sanitized.replace(",", "")) # 1,000 if "," in sanitized and _split_thousands(",", sanitized): return float(sanitized.replace(",", "")) # 5,5556 if "," in sanitized and sanitized.count(",") == 1 and not _split_thousands(",", sanitized): return float(sanitized.replace(",", ".")) # 0.0.0.1 if sanitized.count(".") > 1: return float(sanitized.replace(".", "")) # 0,0,0,1 if sanitized.count(",") > 1: return float(sanitized.replace(",", "")) return float(sanitized) except ValueError: # Avoid adding the sanitized value in the error message. raise ValueError("Unable to convert value to float") def _normalize_float(answer): if answer is None: return None try: value = convert_to_float(answer) if isinstance(value, float) and math.isnan(value): return None return value except ValueError: return answer.lower() _TYPE_CONVERTER = { "text": lambda x: x, "real": convert_to_float, } class _Aggregation(enum.Enum): """Aggregations as defined by WikiSQL. Indexes match the data.""" NONE = 0 MAX = 1 MIN = 2 COUNT = 3 SUM = 4 AVERAGE = 5 class _Operator(enum.Enum): """The boolean operators used by WikiSQL. Indexes match the data.""" EQUALS = 0 GREATER = 1 LESSER = 2 @dataclasses.dataclass class _Condition: """Represents an SQL where clauses (e.g A = "a" or B > 5).""" column: Text operator: _Operator cmp_value: Any _TOKENIZER = re.compile(r"\w+|[^\w\s]+", re.UNICODE | re.MULTILINE | re.DOTALL) def _normalize_for_match(x): return list(_TOKENIZER.findall(x.lower())) def _compare(operator, src, tgt): if operator == _Operator.EQUALS: return src == tgt elif operator == _Operator.GREATER: return src > tgt elif operator == _Operator.LESSER: return src < tgt raise ValueError(f"Unknown operator: {operator}") def _parse_value(table, column, cell_value): """Convert numeric values to floats and keeps everything else as string.""" types = table["types"] return _TYPE_CONVERTER[types[column]](cell_value) def _is_string(x): return isinstance(x, str) def _respect_conditions(table, row, conditions): """True if 'row' satisfies all 'conditions'.""" for cond in conditions: table_value = row[cond.column] cmp_value = _parse_value(table, cond.column, cond.cmp_value) if _is_string(table_value) and _is_string(cmp_value): table_value = _normalize_for_match(table_value) cmp_value = _normalize_for_match(cmp_value) if not isinstance(table_value, type(cmp_value)): raise ValueError("Type difference {} != {}".format(type(table_value), type(cmp_value))) if not _compare(cond.operator, table_value, cmp_value): return False return True def _get_float_answer(table, answer_coordinates, aggregation_op): """Applies operation to produce reference float answer.""" if not answer_coordinates: if aggregation_op == _Aggregation.COUNT: return 0.0 else: return EMPTY_ANSWER_AGG # Count can support non numeric answers. if aggregation_op == _Aggregation.COUNT: return float(len(answer_coordinates)) # If we have just one answer, if float returns it or try a conversion. values = [table["rows"][i][j] for (i, j) in answer_coordinates] if len(answer_coordinates) == 1: try: return convert_to_float(values[0]) except ValueError as e: if aggregation_op != _Aggregation.NONE: raise e if aggregation_op == _Aggregation.NONE: return None # Other aggregation only support numeric values. Bail out if we have strings. if not all((isinstance(v, (int, float)) for v in values)): return None if aggregation_op == _Aggregation.SUM: return float(sum(values)) elif aggregation_op == _Aggregation.AVERAGE: return sum(values) / len(answer_coordinates) else: raise ValueError(f"Unknown aggregation: {aggregation_op}") def _get_answer_coordinates(table, sql_query): """Retrieves references coordinates by executing SQL.""" # MAX and MIN are automatically supported by the model. aggregation_op_index = sql_query["agg"] if aggregation_op_index >= 3: aggregation_op = _Aggregation(aggregation_op_index) else: aggregation_op = _Aggregation.NONE target_column = sql_query["sel"] conditions = [ _Condition(column, _Operator(operator), cmp_value) for column, operator, cmp_value in zip( sql_query["conds"]["column_index"], sql_query["conds"]["operator_index"], sql_query["conds"]["condition"] ) ] indices = [] for row in range(len(table["rows"])): if _respect_conditions(table, table["rows"][row], conditions): indices.append((row, target_column)) if not indices: return [], aggregation_op if len(indices) == 1: return indices, aggregation_op # Parsing of MIN/MAX. if aggregation_op_index in (1, 2): operators = {2: min, 1: max} values = [(table["rows"][i][j], index) for index, (i, j) in enumerate(indices)] reduced = functools.reduce(operators[sql_query["agg"]], values) ret = [indices[reduced[1]]] return ret, _Aggregation.NONE return indices, aggregation_op def _get_answer_text(table, answer_coordinates, float_answer): if float_answer is not None: return [str(float_answer)] return [str(table["real_rows"][r][c]) for r, c in answer_coordinates] def retrieve_wikisql_query_answer_tapas(table, example) -> List: answer_coordinates, aggregation_op = _get_answer_coordinates(table, example) float_answer = _get_float_answer(table, answer_coordinates, aggregation_op) answer_text = _get_answer_text(table, answer_coordinates, float_answer) # keep the original data the same with TaPas if len(answer_text) == 0: answer_text = [EMPTY_ANSWER] return answer_text