Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| """基本方法 | |
| 创建中文数字系统 方法 | |
| 中文字符串 <=> 数字串 方法 | |
| 数字串 <=> 中文字符串 方法 | |
| """ | |
| __author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>" | |
| __data__ = "2019-05-02" | |
| from fish_speech.text.chn_text_norm.basic_class import * | |
| from fish_speech.text.chn_text_norm.basic_constant import * | |
def create_system(numbering_type=NUMBERING_TYPES[1]):
    """Build the Chinese number system for the given numbering type.

    ``numbering_type`` is one of ``NUMBERING_TYPES = ['low', 'mid', 'high']``
    (default 'mid') and decides the magnitude of the units from '亿' upward:

    * low:  '兆' = '亿' * '十' = $10^{9}$,  '京' = '兆' * '十', etc.
    * mid:  '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
    * high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.

    Returns:
        A ``NumberSystem`` whose ``units`` are the small units ('十', '百',
        '千', '万') followed by the larger ones, whose ``digits`` are 0-9
        (with alternative forms set for 0, 1 and 2), and whose ``math``
        holds the positive, negative and decimal-point symbols.
    """
    # Units '亿' and above; their powers depend on numbering_type.
    larger_units = [
        CNU.create(index, names, numbering_type, False)
        for index, names in enumerate(
            zip(
                LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
                LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL,
            )
        )
    ]
    # Units '十', '百', '千', '万'.
    smaller_units = [
        CNU.create(index, names, small_unit=True)
        for index, names in enumerate(
            zip(
                SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
                SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL,
            )
        )
    ]

    # Digits 0-9: plain form is shared by simplified and traditional,
    # plus the big (financial) forms.
    digit_rows = zip(
        CHINESE_DIGIS,
        CHINESE_DIGIS,
        BIG_CHINESE_DIGIS_SIMPLIFIED,
        BIG_CHINESE_DIGIS_TRADITIONAL,
    )
    digits = [CND.create(index, row) for index, row in enumerate(digit_rows)]

    # Alternative spellings for zero, one and two.
    zero, one, two = digits[0], digits[1], digits[2]
    zero.alt_s = ZERO_ALT
    zero.alt_t = ZERO_ALT
    one.alt_s = ONE_ALT
    one.alt_t = ONE_ALT
    two.alt_s, two.alt_t = TWO_ALTS[0], TWO_ALTS[1]

    # Sign and decimal-point symbols.
    math_symbols = MathSymbol(
        CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x),
        CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x),
        CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y))),
    )

    system = NumberSystem()
    system.units = smaller_units + larger_units
    system.digits = digits
    system.math = math_symbols
    return system
def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
    """Convert a Chinese numeral string to a string of Arabic digits.

    Digits (plain, big and alternative forms), units ('十', '百', '千', '万'
    and the larger units of the chosen numbering system) and a single
    decimal point are understood, e.g. "一万零四百零三点八零五" -> "10403.805".

    Colloquial shorthand is expanded: a leading '十' means '一十'
    ("十五" -> "15"), and a digit directly after a unit at the end of the
    integer part implies the power of ten just below that unit
    ("一百八" -> "180", "一万八" -> "18000"). A string made only of digits
    is read positionally ("一零零八六" -> "10086").

    Args:
        chinese_string: the Chinese numeral text.
        numbering_type: one of NUMBERING_TYPES; selects the magnitudes of the
            units from '亿' upward (see create_system).

    Returns:
        "<int>" or "<int>.<decimal digits>"; the decimal digits are copied one
        by one, so trailing zeros are preserved.

    Raises:
        ValueError: if the string contains more than one decimal point.

    Note:
        Sign characters are recognised by the symbol lookup but do not affect
        the result, and unrecognised characters are ignored in the integer
        part.
    """

    def get_symbol(char, system):
        """Return the unit, digit or math symbol spelled by ``char``, else None."""
        for u in system.units:
            if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
                return u
        for d in system.digits:
            if char in [
                d.traditional,
                d.simplified,
                d.big_s,
                d.big_t,
                d.alt_s,
                d.alt_t,
            ]:
                return d
        for m in system.math:
            if char in [m.traditional, m.simplified]:
                return m
        return None

    def string2symbols(chinese_string, system):
        """Split at the decimal point and map every character to a symbol."""
        int_string, dec_string = chinese_string, ""
        for p in [system.math.point.simplified, system.math.point.traditional]:
            if p in chinese_string:
                # More than one point makes this unpacking raise ValueError.
                int_string, dec_string = chinese_string.split(p)
                break
        return (
            [get_symbol(c, system) for c in int_string],
            [get_symbol(c, system) for c in dec_string],
        )

    def correct_symbols(integer_symbols, system):
        """Expand colloquial shorthand in the integer part.

        '十五' -> '一十五': a leading '十' gets an explicit '一'.
        '一百八' -> '一百八十': a trailing digit right after a unit is
        followed by the power of ten one below that unit.
        Consecutive units ('一千万') are left as-is; compute_value handles them.
        """
        if integer_symbols and isinstance(integer_symbols[0], CNU):
            if integer_symbols[0].power == 1:
                integer_symbols = [system.digits[1]] + integer_symbols
        if (
            len(integer_symbols) > 1
            and isinstance(integer_symbols[-1], CND)
            and isinstance(integer_symbols[-2], CNU)
        ):
            integer_symbols.append(
                CNU(integer_symbols[-2].power - 1, None, None, None, None)
            )
        return integer_symbols

    def compute_value(integer_symbols):
        """Compute the integer value of a corrected symbol sequence.

        A unit multiplies the digit just before it plus every earlier group
        whose unit is smaller than itself, but never groups already scaled by
        a larger unit: '两千万' = (2000) * 10^4, while in '一亿一百零五万'
        the '万' scales only '一百零五' and leaves '一亿' alone.
        """
        # Pure digit strings ("一零零八六") are read digit by digit.
        if integer_symbols and all(isinstance(s, CND) for s in integer_symbols):
            return int("".join(str(s.value) for s in integer_symbols))

        # Stack of (value, unit power); powers never increase towards the top,
        # so the groups a new unit absorbs are always at the top.
        groups = []
        pending = None  # digit seen since the last unit
        for s in integer_symbols:
            if isinstance(s, CND):
                # '零' only marks a gap ("一百零五"), so a later digit replaces it.
                pending = s.value
            elif isinstance(s, CNU):
                section = pending or 0
                absorbed = False
                while groups and groups[-1][1] < s.power:
                    section += groups.pop()[0]
                    absorbed = True
                if not pending and not absorbed:
                    # Bare unit with nothing to scale, e.g. the '十' in
                    # '一亿十五' or '一百万零十': it stands for one of itself.
                    section = 1
                groups.append((section * pow(10, s.power), s.power))
                pending = None
        return sum(value for value, _ in groups) + (pending or 0)

    system = create_system(numbering_type)
    int_part, dec_part = string2symbols(chinese_string, system)
    int_part = correct_symbols(int_part, system)
    int_str = str(compute_value(int_part))
    dec_str = "".join([str(d.value) for d in dec_part])
    if dec_part:
        return "{0}.{1}".format(int_str, dec_str)
    else:
        return int_str
def num2chn(
    number_string,
    numbering_type=NUMBERING_TYPES[1],
    big=False,
    traditional=False,
    alt_zero=False,
    alt_one=False,
    alt_two=True,
    use_zeros=True,
    use_units=True,
):
    """Convert a digit string (with at most one '.') to Chinese numerals.

    Args:
        number_string: Arabic digits, optionally with one decimal point,
            e.g. "10260.03".
        numbering_type: one of NUMBERING_TYPES; selects the magnitudes of the
            units from '亿' upward (see create_system).
        big: render with the big (financial) character forms.
        traditional: render traditional rather than simplified characters.
        alt_zero: replace every zero character with the alternative zero form.
        alt_one: replace every one character with the alternative one form.
        alt_two: use the alternative '2' form for a '2' that stands at the
            start or right after a unit, and right before a unit, when neither
            of those units is '十'. Has no visible effect when ``big`` is true.
        use_zeros: emit a zero character for gaps in the integer part
            (10260 -> 一万零二百六十); with False the gaps are skipped.
        use_units: read the integer part with units; when False (or when it
            has a single digit) it is read digit by digit.

    Returns:
        The Chinese string. An empty integer part before a decimal point, or
        an all-zero one, is rendered as a zero digit. A leading one-digit
        followed by the plain '十' unit is dropped (一十五 -> 十五).

    Raises:
        ValueError: if ``number_string`` contains more than one '.'.
    """

    def get_value(value_string, use_zeros=True):
        """Return the symbols reading the digit string ``value_string`` with units.

        When ``use_zeros`` is true, leading zeros in ``value_string`` surface
        as one zero symbol in front of the next non-zero digit.
        """
        striped_string = value_string.lstrip("0")
        # all zeros: contributes nothing
        if not striped_string:
            return []
        # a single significant digit
        if len(striped_string) == 1:
            digit = system.digits[int(striped_string)]
            if use_zeros and len(value_string) != len(striped_string):
                return [system.digits[0], digit]
            return [digit]
        # Split at the last unit (in system.units order) whose power is smaller
        # than the number of digits; presumably the largest fitting unit,
        # assuming system.units is ordered by increasing power.
        result_unit = next(
            u for u in reversed(system.units) if u.power < len(striped_string)
        )
        # The head is cut from value_string (not striped_string) so that leading
        # zeros handed down from an enclosing tail still produce a zero symbol.
        head = value_string[: -result_unit.power]
        tail = striped_string[-result_unit.power :]
        return (
            get_value(head, use_zeros)
            + [result_unit]
            + get_value(tail, use_zeros)
        )

    system = create_system(numbering_type)

    int_dec = number_string.split(".")
    if len(int_dec) == 1:
        int_string = int_dec[0]
        dec_string = ""
    elif len(int_dec) == 2:
        int_string = int_dec[0]
        dec_string = int_dec[1]
    else:
        raise ValueError(
            "invalid input num string with more than one dot: {}".format(number_string)
        )

    if use_units and len(int_string) > 1:
        result_symbols = get_value(int_string, use_zeros)
    else:
        result_symbols = [system.digits[int(c)] for c in int_string]
    # An all-zero integer part ("00") or a missing one (".5") yields no symbols;
    # read it as a single zero digit so the output is not empty and does not
    # start with the decimal point (and alt_zero/big/traditional still apply).
    if not result_symbols and (int_string or dec_string):
        result_symbols = [system.digits[0]]
    if dec_string:
        dec_symbols = [system.digits[int(c)] for c in dec_string]
        result_symbols += [system.math.point] + dec_symbols

    if alt_two:
        liang = CND(
            2,
            system.digits[2].alt_s,
            system.digits[2].alt_t,
            system.digits[2].big_s,
            system.digits[2].big_t,
        )
        for i, v in enumerate(result_symbols):
            if isinstance(v, CND) and v.value == 2:
                next_symbol = (
                    result_symbols[i + 1] if i < len(result_symbols) - 1 else None
                )
                previous_symbol = result_symbols[i - 1] if i > 0 else None
                if isinstance(next_symbol, CNU) and isinstance(
                    previous_symbol, (CNU, type(None))
                ):
                    if next_symbol.power != 1 and (
                        (previous_symbol is None) or (previous_symbol.power != 1)
                    ):
                        result_symbols[i] = liang

    # liang keeps the ordinary big forms, so with big=True alt_two has no
    # visible effect on the output.
    if big:
        attr_name = "big_t" if traditional else "big_s"
    else:
        attr_name = "traditional" if traditional else "simplified"

    result = "".join([getattr(s, attr_name) for s in result_symbols])
    if alt_zero:
        result = result.replace(
            getattr(system.digits[0], attr_name), system.digits[0].alt_s
        )
    if alt_one:
        result = result.replace(
            getattr(system.digits[1], attr_name), system.digits[1].alt_s
        )

    # Numbers starting with 10..19 are read '十X', not '一十X'.
    if (
        len(result) >= 2
        and result[1]
        in [
            SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
            SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
        ]
        and result[0]
        in [
            CHINESE_DIGIS[1],
            BIG_CHINESE_DIGIS_SIMPLIFIED[1],
            BIG_CHINESE_DIGIS_TRADITIONAL[1],
        ]
    ):
        result = result[1:]
    return result
if __name__ == "__main__":
    # Manual smoke test: print a few conversions in both directions.
    all_chinese_number_string = (
        CHINESE_DIGIS
        + BIG_CHINESE_DIGIS_SIMPLIFIED
        + BIG_CHINESE_DIGIS_TRADITIONAL
        + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED
        + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
        + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED
        + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
        + ZERO_ALT
        + ONE_ALT
        + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT)
    )

    print("num:", chn2num("一万零四百零三点八零五"))
    print("num:", chn2num("一亿六点三"))
    print("num:", chn2num("一亿零六点三"))
    print("num:", chn2num("两千零一亿六点三"))
    # print('num:', chn2num('一零零八六'))

    print("txt:", num2chn("10260.03", alt_zero=True))
    print("txt:", num2chn("20037.090", numbering_type="low", traditional=True))
    print("txt:", num2chn("100860001.77", numbering_type="high", big=True))
    # num2chn's only zero-related option is use_zeros (there are no
    # use_lzeros/use_rzeros parameters; passing them raises TypeError).
    print(
        "txt:",
        num2chn(
            "059523810880",
            alt_one=True,
            alt_two=False,
            use_zeros=True,
            use_units=False,
        ),
    )
    print(all_chinese_number_string)