import re

import torch
import torch.nn.functional as F


class CustomRepetitionPenaltyLogitsProcessorRepeat:
    """Repetition penalty that scales with how often a token recently occurred.

    For every token id seen in the last ``past_window`` positions of
    ``input_ids`` the score is pushed towards "less likely" by a factor of
    ``penalty ** count``.  Token ids at or above ``max_input_ids`` are exempt.
    """

    def __init__(self, penalty: float, max_input_ids, past_window):
        """Store the configuration.

        Args:
            penalty: Strictly positive ``float`` penalty base.
            max_input_ids: First token id that is exempt from the penalty.
            past_window: Number of trailing positions of ``input_ids`` to inspect.

        Raises:
            ValueError: If ``penalty`` is not a ``float`` or is not > 0.
        """
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty
        self.max_input_ids = max_input_ids
        self.past_window = past_window

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with the frequency-scaled penalty applied.

        Args:
            input_ids: Integer tensor indexed as ``[row, position]``.
            scores: Float tensor indexed as ``[row, token_id]``; ``scores.size(1)``
                is used as the number of token classes.

        Returns:
            A new tensor of penalised scores (``scores`` itself is not modified).
        """
        recent_ids = input_ids[:, -self.past_window:]
        # Per-row occurrence count of every token id inside the window.
        freq = F.one_hot(recent_ids, scores.size(1)).sum(1)
        # Exempt ids live on the token axis (dim 1).  Slicing dim 0 here would
        # zero whole batch rows instead and leave the exempt ids penalised.
        freq[:, self.max_input_ids:] = 0
        alpha = self.penalty ** freq
        # Negative scores are multiplied and positive ones divided so that a
        # penalty > 1 always lowers the score.
        return torch.where(scores < 0, scores * alpha, scores / alpha)


class CustomRepetitionPenaltyLogitsProcessor:
    """Flat repetition penalty applied once to each recently seen token id.

    Unlike the ``...Repeat`` variant the penalty does not grow with the number
    of occurrences.  Token ids at or above ``max_input_ids`` are exempt.
    """

    def __init__(self, penalty: float, max_input_ids, past_window):
        """Store the configuration.

        Args:
            penalty: Strictly positive ``float`` penalty factor.
            max_input_ids: First token id that is exempt from the penalty.
            past_window: Number of trailing positions of ``input_ids`` to inspect.

        Raises:
            ValueError: If ``penalty`` is not a ``float`` or is not > 0.
        """
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty
        self.max_input_ids = max_input_ids
        self.past_window = past_window

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Penalise recently seen tokens in ``scores`` and return it.

        Args:
            input_ids: Integer tensor indexed as ``[row, position]``.
            scores: Float tensor indexed as ``[row, token_id]``.

        Returns:
            The same ``scores`` tensor, which is updated in place.
        """
        recent_ids = input_ids[:, -self.past_window:]
        score = torch.gather(scores, 1, recent_ids)
        untouched = score.detach().clone()
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
        # Put the original value back for ids that must not be penalised.
        exempt = recent_ids >= self.max_input_ids
        score[exempt] = untouched[exempt]
        scores.scatter_(1, recent_ids, score)
        return scores


def count_invalid_characters(s):
    """Return the set of characters in ``s`` that are not accepted.

    The control tags ``[uv_break]``, ``[laugh]`` and ``[lbreak]`` are removed
    first.  Accepted characters are CJK ideographs in U+4E00..U+9FFF, ASCII
    letters, the punctuation listed in the pattern, space and newline.
    """
    s = re.sub(r'\[uv_break\]|\[laugh\]|\[lbreak\]', '', s)
    # NOTE(review): the comma is listed twice; one of them may originally have
    # been the fullwidth comma (U+FF0C) - confirm against the upstream file.
    pattern = re.compile(r'[^\u4e00-\u9fffA-Za-z,。、,\. \n]')
    return set(pattern.findall(s))


def detect_language(sentence):
    """Return ``"zh"`` if CJK ideographs outnumber English words, else ``"en"``.

    Ties, including the empty string, resolve to ``"en"``.
    """
    chinese_chars = re.findall(r'[\u4e00-\u9fff]', sentence)
    english_words = re.findall(r'\b[A-Za-z]+\b', sentence)
    return "zh" if len(chinese_chars) > len(english_words) else "en"


# Punctuation that is rewritten to a comma or full stop, or dropped.
# The fullwidth keys are spelled as escapes so that Unicode normalisation of
# this file cannot fold them into their ASCII twins (which would turn them into
# duplicate keys that the ASCII entries further down silently override).
character_map = {
    # Fullwidth and CJK punctuation.
    '\uff1a': ',',  # fullwidth colon
    '\uff1b': ',',  # fullwidth semicolon
    '\uff01': '。',  # fullwidth exclamation mark
    '\uff08': ',',  # fullwidth left parenthesis
    '\uff09': ',',  # fullwidth right parenthesis
    '【': ',',
    '】': ',',
    '『': ',',
    '』': ',',
    '「': ',',
    '」': ',',
    '《': ',',
    '》': ',',
    '\uff0d': ',',  # fullwidth hyphen-minus
    '‘': '',
    '“': '',
    '’': '',
    '”': '',
    # ASCII punctuation.
    ':': ',',
    ';': ',',
    '!': '.',
    '(': ',',
    ')': ',',
    '[': ',',
    ']': ',',
    '>': ',',
    '<': ',',
    '-': ',',
}

# ASCII punctuation whose fullwidth form is simply the code point + 0xFEE0.
# '[', ']' and '_' are deliberately left out (presumably so control tags such
# as [uv_break] survive the conversion - confirm with the callers).
_SHIFTED_PUNCTUATION = '!#$%&()*+,-/:;<=>?@\\^`{|}~'
_FULLWIDTH_OFFSET = 0xFEE0

halfwidth_2_fullwidth_map = {
    char: chr(ord(char) + _FULLWIDTH_OFFSET) for char in _SHIFTED_PUNCTUATION
}
# These three do not map to their plain fullwidth twin.
halfwidth_2_fullwidth_map.update({
    '"': '“',
    "'": '‘',
    '.': '。',
})


def apply_half2full_map(text):
    """Return ``text`` with ASCII punctuation replaced by fullwidth forms."""
    # The table is rebuilt on every call so later edits to the dict are honoured.
    translation_table = str.maketrans(halfwidth_2_fullwidth_map)
    return text.translate(translation_table)


def apply_character_map(text):
    """Return ``text`` with punctuation rewritten according to ``character_map``."""
    # The table is rebuilt on every call so later edits to the dict are honoured.
    translation_table = str.maketrans(character_map)
    return text.translate(translation_table)