Spaces:
Paused
Paused
| # coding=utf-8 | |
| # Copyright 2022 The Microsoft, The Google and The HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import dataclasses | |
| import enum | |
| import functools | |
| import math | |
| import re | |
| # The following script is adapted from the script of TaPas. | |
| # Original: https://github.com/google-research/tapas/master/wikisql_utils.py | |
| from typing import Any, List, Text | |
# Placeholder returned by `retrieve_wikisql_query_answer_tapas` when no answer text was produced.
EMPTY_ANSWER = "none"
# Placeholder returned by `_get_float_answer` for a non-COUNT query that selects no cells.
EMPTY_ANSWER_AGG = "none"
| def _split_thousands(delimiter, value): | |
| split = value.split(delimiter) | |
| return len(split) > 1 and any((len(x) == 3 for x in split)) | |
def convert_to_float(value):
    """Converts value to a float using a series of increasingly complex heuristics.

    Floats are returned unchanged and ints are cast. Strings are parsed by
    trying, in order, several conventions for "," and "." as thousands or
    decimal separators (see the inline examples).

    Args:
      value: object that needs to be converted. Allowed types include
        float/int/strings.

    Returns:
      A float interpretation of value.

    Raises:
      ValueError if the float conversion of value fails, or if value is not a
      float, int or string.
    """
    if isinstance(value, float):
        return value
    if isinstance(value, int):
        return float(value)
    if not isinstance(value, str):
        raise ValueError("Argument value is not a string. Can't parse it as float")
    sanitized = value

    try:
        # Example: 1,000.7
        if "." in sanitized and "," in sanitized:
            return float(sanitized.replace(",", ""))
        # 1,000
        if "," in sanitized and _split_thousands(",", sanitized):
            return float(sanitized.replace(",", ""))
        # 5,5556
        if "," in sanitized and sanitized.count(",") == 1 and not _split_thousands(",", sanitized):
            return float(sanitized.replace(",", "."))
        # 0.0.0.1
        if sanitized.count(".") > 1:
            return float(sanitized.replace(".", ""))
        # 0,0,0,1
        if sanitized.count(",") > 1:
            return float(sanitized.replace(",", ""))
        return float(sanitized)
    except ValueError:
        # Avoid adding the sanitized value in the error message. `from None`
        # also suppresses the chained float() error, whose message would
        # otherwise print the raw value in the traceback.
        raise ValueError("Unable to convert value to float") from None
| def _normalize_float(answer): | |
| if answer is None: | |
| return None | |
| try: | |
| value = convert_to_float(answer) | |
| if isinstance(value, float) and math.isnan(value): | |
| return None | |
| return value | |
| except ValueError: | |
| return answer.lower() | |
# Converters keyed by column type name; `_parse_value` picks one using the
# entry of table["types"] for the column being compared.
_TYPE_CONVERTER = {
    # Text cells are kept exactly as they are.
    "text": lambda x: x,
    # Numeric cells go through the float heuristics.
    "real": convert_to_float,
}
| class _Aggregation(enum.Enum): | |
| """Aggregations as defined by WikiSQL. Indexes match the data.""" | |
| NONE = 0 | |
| MAX = 1 | |
| MIN = 2 | |
| COUNT = 3 | |
| SUM = 4 | |
| AVERAGE = 5 | |
| class _Operator(enum.Enum): | |
| """The boolean operators used by WikiSQL. Indexes match the data.""" | |
| EQUALS = 0 | |
| GREATER = 1 | |
| LESSER = 2 | |
| class _Condition: | |
| """Represents an SQL where clauses (e.g A = "a" or B > 5).""" | |
| column: Text | |
| operator: _Operator | |
| cmp_value: Any | |
| _TOKENIZER = re.compile(r"\w+|[^\w\s]+", re.UNICODE | re.MULTILINE | re.DOTALL) | |
| def _normalize_for_match(x): | |
| return list(_TOKENIZER.findall(x.lower())) | |
def _compare(operator, src, tgt):
    """Apply a comparison operator to two values.

    Args:
      operator: the `_Operator` member selecting the comparison.
      src: left-hand operand.
      tgt: right-hand operand.

    Returns:
      `src == tgt`, `src > tgt` or `src < tgt` for EQUALS, GREATER and LESSER
      respectively.

    Raises:
      ValueError if `operator` is none of the three known operators.
    """
    if operator == _Operator.EQUALS:
        return src == tgt
    if operator == _Operator.GREATER:
        return src > tgt
    if operator == _Operator.LESSER:
        return src < tgt
    raise ValueError(f"Unknown operator: {operator}")
def _parse_value(table, column, cell_value):
    """Convert numeric values to floats and keep everything else as a string.

    Args:
      table: mapping whose "types" entry gives the type name of each column.
      column: key of the column the value belongs to.
      cell_value: value to convert.

    Returns:
      `cell_value` passed through the `_TYPE_CONVERTER` entry registered for
      the column's type.
    """
    column_type = table["types"][column]
    converter = _TYPE_CONVERTER[column_type]
    return converter(cell_value)
| def _is_string(x): | |
| return isinstance(x, str) | |
| def _respect_conditions(table, row, conditions): | |
| """True if 'row' satisfies all 'conditions'.""" | |
| for cond in conditions: | |
| table_value = row[cond.column] | |
| cmp_value = _parse_value(table, cond.column, cond.cmp_value) | |
| if _is_string(table_value) and _is_string(cmp_value): | |
| table_value = _normalize_for_match(table_value) | |
| cmp_value = _normalize_for_match(cmp_value) | |
| if not isinstance(table_value, type(cmp_value)): | |
| raise ValueError("Type difference {} != {}".format(type(table_value), type(cmp_value))) | |
| if not _compare(cond.operator, table_value, cmp_value): | |
| return False | |
| return True | |
def _get_float_answer(table, answer_coordinates, aggregation_op):
    """Applies the aggregation to produce the reference float answer.

    Args:
      table: mapping whose "rows" entry is indexable as rows[i][j].
      answer_coordinates: sequence of (row, column) pairs selecting cells.
      aggregation_op: the `_Aggregation` member to apply.

    Returns:
      * With no coordinates: 0.0 for COUNT, otherwise `EMPTY_ANSWER_AGG`.
      * For COUNT: the number of coordinates as a float.
      * For a single coordinate: the cell converted with `convert_to_float`.
      * For NONE with several cells, or a single unconvertible cell: None.
      * None if any selected cell is not an int or float.
      * For SUM / AVERAGE: the sum / mean of the selected cells.

    Raises:
      ValueError if a single selected cell cannot be converted while an
      aggregation other than NONE is requested, or if the aggregation is not
      one handled here.
    """
    counting = aggregation_op == _Aggregation.COUNT

    if not answer_coordinates:
        return 0.0 if counting else EMPTY_ANSWER_AGG

    # Count can support non numeric answers.
    if counting:
        return float(len(answer_coordinates))

    cells = [table["rows"][row][col] for row, col in answer_coordinates]

    # With exactly one cell, return it as a float when it converts.
    if len(answer_coordinates) == 1:
        try:
            return convert_to_float(cells[0])
        except ValueError:
            # A non-numeric single cell is only tolerated without aggregation.
            if aggregation_op != _Aggregation.NONE:
                raise

    if aggregation_op == _Aggregation.NONE:
        return None

    # Other aggregations only support numeric values. Bail out if we have strings.
    if any(not isinstance(cell, (int, float)) for cell in cells):
        return None

    if aggregation_op == _Aggregation.SUM:
        return float(sum(cells))
    if aggregation_op == _Aggregation.AVERAGE:
        return sum(cells) / len(answer_coordinates)
    raise ValueError(f"Unknown aggregation: {aggregation_op}")
def _get_answer_coordinates(table, sql_query):
    """Retrieves reference coordinates by executing the SQL query on the table.

    Args:
      table: mapping with a "rows" entry (plus whatever `_respect_conditions`
        reads from it).
      sql_query: mapping with "agg" (aggregation index), "sel" (selected
        column) and "conds", itself a mapping with parallel "column_index",
        "operator_index" and "condition" sequences.

    Returns:
      A pair `(coordinates, aggregation_op)`. `coordinates` is a list of
      (row, column) tuples for the rows matching every condition. For MIN/MAX
      with several matching rows it is reduced to the single extreme cell and
      the aggregation is reported as NONE.
    """
    agg_index = sql_query["agg"]
    # MAX and MIN are automatically supported by the model, so only indexes
    # from 3 upwards are kept as a real aggregation.
    aggregation_op = _Aggregation(agg_index) if agg_index >= 3 else _Aggregation.NONE

    target_column = sql_query["sel"]
    conds = sql_query["conds"]
    conditions = [
        _Condition(column, _Operator(operator_index), cmp_value)
        for column, operator_index, cmp_value in zip(
            conds["column_index"], conds["operator_index"], conds["condition"]
        )
    ]

    rows = table["rows"]
    indices = [
        (row_index, target_column)
        for row_index, row in enumerate(rows)
        if _respect_conditions(table, row, conditions)
    ]

    # Zero or one match: nothing to reduce.
    if len(indices) <= 1:
        return indices, aggregation_op

    # Parsing of MIN/MAX.
    if agg_index in (1, 2):
        pick = max if agg_index == 1 else min
        # Pair each cell with its position so the winning coordinate can be recovered.
        candidates = [(rows[i][j], position) for position, (i, j) in enumerate(indices)]
        _, best_position = functools.reduce(pick, candidates)
        return [indices[best_position]], _Aggregation.NONE

    return indices, aggregation_op
| def _get_answer_text(table, answer_coordinates, float_answer): | |
| if float_answer is not None: | |
| return [str(float_answer)] | |
| return [str(table["real_rows"][r][c]) for r, c in answer_coordinates] | |
def retrieve_wikisql_query_answer_tapas(table, example) -> List:
    """Execute a WikiSQL query against a table and return the answer strings.

    Args:
      table: table mapping passed to `_get_answer_coordinates`,
        `_get_float_answer` and `_get_answer_text`.
      example: the SQL query mapping passed to `_get_answer_coordinates`.

    Returns:
      The list of answer strings, or `[EMPTY_ANSWER]` when that list would be
      empty.
    """
    coordinates, aggregation_op = _get_answer_coordinates(table, example)
    float_answer = _get_float_answer(table, coordinates, aggregation_op)
    answer_text = _get_answer_text(table, coordinates, float_answer)
    # keep the original data the same with TaPas
    return answer_text if len(answer_text) != 0 else [EMPTY_ANSWER]