Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import collections | |
from math import sqrt | |
from itertools import chain, tee | |
from functools import lru_cache | |
import scipy.stats | |
import torch | |
from tokenizers import Tokenizer | |
from transformers import LogitsProcessor | |
from normalizers import normalization_strategy_lookup | |
from alternative_prf_schemes import prf_lookup, seeding_scheme_lookup | |
class WatermarkBase:
    """Shared configuration and green-list construction for watermark generation and detection.

    Args:
        vocab: The vocabulary token ids. Only its length is used by this class.
        gamma: Fraction of the vocabulary placed in each green list.
        delta: Logit bias for green tokens (stored here, applied by subclasses).
        seeding_scheme: Public name of the seeding scheme, resolved through
            ``seeding_scheme_lookup``. ``None`` falls back to ``"selfhash"``.
        select_green_tokens: If True, the green list is taken from the front of the
            vocabulary permutation; if False, the legacy behavior takes it from the end.

    Raises:
        ValueError: If ``vocab`` is not provided.
    """

    def __init__(
        self,
        vocab: list[int] = None,
        gamma: float = 0.25,
        delta: float = 2.0,
        seeding_scheme: str = "selfhash",
        select_green_tokens: bool = True,
    ):
        # Callers may pass None explicitly as seeding_scheme, so patch it to the default here.
        if seeding_scheme is None:
            seeding_scheme = "selfhash"
        if vocab is None:
            raise ValueError("A vocabulary (sequence of token ids) is required to build green lists.")

        # Vocabulary setup.
        self.vocab = vocab
        self.vocab_size = len(vocab)

        # Watermark behavior.
        self.gamma = gamma
        self.delta = delta
        # Created lazily on first use (or eagerly by subclasses) on the device of the input ids.
        self.rng = None
        self._initialize_seeding_scheme(seeding_scheme)

        # Legacy behavior switch.
        self.select_green_tokens = select_green_tokens

    def _initialize_seeding_scheme(self, seeding_scheme: str) -> None:
        """Set prf_type, context_width, self_salt and hash_key from a seeding scheme's public name."""
        self.prf_type, self.context_width, self.self_salt, self.hash_key = seeding_scheme_lookup(seeding_scheme)

    def _seed_rng(self, input_ids: torch.LongTensor) -> None:
        """Seed ``self.rng`` from the last ``context_width`` tokens of ``input_ids``.

        Not batched, because the generators used (e.g. cuda.random) are not batched.

        Raises:
            ValueError: If ``input_ids`` is shorter than ``context_width``.
        """
        # Enough tokens are needed to form the seeding context.
        if input_ids.shape[-1] < self.context_width:
            raise ValueError(f"seeding_scheme requires at least a {self.context_width} token prefix to seed the RNG.")

        # The generator is always reseeded below, so (re)creating it here never changes results;
        # it only guarantees that it lives on the same device as the tensors it will be used with.
        if self.rng is None or self.rng.device != input_ids.device:
            self.rng = torch.Generator(device=input_ids.device)

        # NOTE(review): assumes the PRF returns a Python int; confirm against alternative_prf_schemes.
        prf_key = prf_lookup[self.prf_type](input_ids[-self.context_width :], salt_key=self.hash_key)
        self.rng.manual_seed(prf_key % (2**64 - 1))  # keep the seed within manual_seed's accepted range

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        """Seed the RNG from the local context and return the ids of the resulting green list.

        The green list holds ``int(vocab_size * gamma)`` ids drawn from a seeded random
        permutation of the vocabulary.
        """
        self._seed_rng(input_ids)

        greenlist_size = int(self.vocab_size * self.gamma)
        vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
        if self.select_green_tokens:  # select green tokens directly
            greenlist_ids = vocab_permutation[:greenlist_size]
        else:  # select green via red (legacy behavior)
            greenlist_ids = vocab_permutation[(self.vocab_size - greenlist_size) :]
        return greenlist_ids
class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
    """LogitsProcessor that adds ``delta`` to the scores of green-list tokens.

    It can be used in any Hugging Face generation pipeline to watermark model output. It can
    also be inserted as a standalone step into any model that produces scores between the
    model output and the next-token sampler.

    Args:
        store_spike_ents: If True, record the spike entropy of every scored distribution,
            per batch row, in ``self.spike_entropies``.
    """

    def __init__(self, *args, store_spike_ents: bool = False, **kwargs):
        super().__init__(*args, **kwargs)

        self.store_spike_ents = store_spike_ents
        self.spike_entropies = None
        if self.store_spike_ents:
            self._init_spike_entropies()

    def _init_spike_entropies(self):
        """Precompute ``z_value`` and ``expected_gl_coef`` from ``delta`` and ``gamma``."""
        alpha = torch.exp(torch.tensor(self.delta)).item()
        gamma = self.gamma

        self.z_value = ((1 - gamma) * (alpha - 1)) / (1 - gamma + (alpha * gamma))
        self.expected_gl_coef = (gamma * alpha) / (1 - gamma + (alpha * gamma))

        # Catch overflow when the bias is "infinite" (both ratios above become inf/inf).
        # NOTE(review): as alpha -> inf, z_value tends to (1 - gamma) / gamma rather than 1.0;
        # confirm which value is intended here.
        if alpha == torch.inf:
            self.z_value = 1.0
            self.expected_gl_coef = 1.0

    def _get_spike_entropies(self):
        """Return the stored spike entropies as nested Python floats, one list per batch row."""
        return [[ent_tensor.item() for ent_tensor in ent_tensor_list] for ent_tensor_list in self.spike_entropies]

    def _get_and_clear_stored_spike_ents(self):
        """Return the stored spike entropies and reset the store."""
        spike_ents = self._get_spike_entropies()
        self.spike_entropies = None
        return spike_ents

    def _compute_spike_entropy(self, scores):
        """Spike entropy of the softmax of ``scores``: sum_k p_k / (1 + z_value * p_k).

        Requires ``_init_spike_entropies`` to have run (it sets ``z_value``).
        """
        probs = scores.softmax(dim=-1)
        denoms = 1 + (self.z_value * probs)
        renormed_probs = probs / denoms
        sum_renormed_probs = renormed_probs.sum()
        return sum_renormed_probs

    def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
        """Build a boolean mask shaped like ``scores`` that is True at each row's green-list ids."""
        green_tokens_mask = torch.zeros_like(scores, dtype=torch.bool)
        for b_idx, greenlist in enumerate(greenlist_token_ids):
            # Rejection sampling may produce an empty green list; leave that row all-False.
            if len(greenlist) > 0:
                green_tokens_mask[b_idx][greenlist] = True
        return green_tokens_mask

    def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
        """Add ``greenlist_bias`` to the masked entries of ``scores`` (modifies ``scores`` in place)."""
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def _score_rejection_sampling(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, tail_rule="fixed_compute") -> torch.Tensor:
        """Build a green list from the candidate next tokens, keeping only self-consistent candidates.

        Candidates are visited in descending score order. Each candidate is appended to the prefix,
        the green list induced by that extended prefix is computed, and the candidate is kept if it
        lies in its own green list. Not batched.

        This is only a partial version of Algorithm 3 ("Robust Private Watermarking"), because it
        always assumes greedy sampling. It still (somewhat) works for all types of sampling, but is
        less effective.

        ``tail_rule`` selects an early-stopping rule for the tail of the distribution (not exposed
        by default):
            * "fixed_score": stop once the next candidate trails the best score by more than delta.
            * "fixed_list_length": stop once 10 green candidates were found.
            * "fixed_compute": stop after 41 candidates were examined.
            * anything else: examine every candidate.

        Returns:
            A 1-D tensor of accepted candidate ids on ``input_ids.device``.
        """
        sorted_scores, greedy_predictions = scores.sort(dim=-1, descending=True)

        final_greenlist = []
        for idx, prediction_candidate in enumerate(greedy_predictions):
            greenlist_ids = self._get_greenlist_ids(torch.cat([input_ids, prediction_candidate[None]], dim=0))  # add candidate to prefix
            if prediction_candidate in greenlist_ids:  # test for consistency
                final_greenlist.append(prediction_candidate)

            # Optional early-stopping rules, for efficiency.
            if tail_rule == "fixed_score":
                # Also stop when no further candidate exists (avoids indexing past the end).
                if idx + 1 >= len(sorted_scores) or sorted_scores[0] - sorted_scores[idx + 1] > self.delta:
                    break
            elif tail_rule == "fixed_list_length":
                if len(final_greenlist) == 10:
                    break
            elif tail_rule == "fixed_compute":
                if idx == 40:
                    break
            else:
                pass  # do not break early
        return torch.as_tensor(final_greenlist, device=input_ids.device)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Bias ``scores`` (next-token logits) towards each row's green list, given the prefix ``input_ids``."""
        # The generator is reseeded before every use, so recreating it on a device change is safe.
        if self.rng is None or self.rng.device != input_ids.device:
            self.rng = torch.Generator(device=input_ids.device)

        # It would be nice to remove this batch loop, but currently the seeding and partitioning
        # operations are not tensor/vectorized, so each sequence in the batch is handled separately.
        list_of_greenlist_ids = [None for _ in input_ids]  # green lists may differ in length
        for b_idx, input_seq in enumerate(input_ids):
            if self.self_salt:
                greenlist_ids = self._score_rejection_sampling(input_seq, scores[b_idx])
            else:
                greenlist_ids = self._get_greenlist_ids(input_seq)
            list_of_greenlist_ids[b_idx] = greenlist_ids

            # Compute and store the spike entropy of this row's (still unbiased) distribution.
            if self.store_spike_ents:
                if self.spike_entropies is None:
                    self.spike_entropies = [[] for _ in range(input_ids.shape[0])]
                self.spike_entropies[b_idx].append(self._compute_spike_entropy(scores[b_idx]))

        green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=list_of_greenlist_ids)
        scores = self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)

        return scores
class WatermarkDetector(WatermarkBase):
    """Detector for watermarks embedded with ``WatermarkLogitsProcessor``.

    The detector must be given exactly the settings that were used during generation, so that it
    can reproduce the green lists. This includes the device, the tokenizer, the ``seeding_scheme``
    name and the parameters (``delta``, ``gamma``).

    Optional arguments:
        * normalizers: any of ["unicode", "homoglyphs", "truecase"]. These can mitigate
          modifications to the generated text that would otherwise disturb detection.
        * ignore_repeated_ngrams: count each unique n-gram only once.
        * z_threshold: z-score above which text is declared watermarked; changing it changes the
          detector's sensitivity.
    """

    def __init__(
        self,
        *args,
        device: torch.device = None,
        tokenizer: Tokenizer = None,
        z_threshold: float = 4.0,
        normalizers: list[str] | tuple[str, ...] = ("unicode",),  # or also: ["unicode", "homoglyphs", "truecase"]
        ignore_repeated_ngrams: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Configuration.
        assert device, "Must pass device"
        assert tokenizer, "Need an instance of the generating tokenizer to perform detection"

        self.tokenizer = tokenizer
        self.device = device
        self.z_threshold = z_threshold
        self.rng = torch.Generator(device=self.device)

        self.normalizers = [normalization_strategy_lookup(strategy) for strategy in normalizers]
        self.ignore_repeated_ngrams = ignore_repeated_ngrams

    def dummy_detect(
        self,
        return_prediction: bool = True,
        return_scores: bool = True,
        z_threshold: float = None,
        return_num_tokens_scored: bool = True,
        return_num_green_tokens: bool = True,
        return_green_fraction: bool = True,
        return_green_token_mask: bool = False,
        return_all_window_scores: bool = False,
        return_z_score: bool = True,
        return_z_at_T: bool = True,
        return_p_value: bool = True,
    ):
        """Return a result shaped like ``detect``'s, filled with NaN/empty placeholders and a negative prediction."""
        # HF-style output dictionary.
        score_dict = dict()
        if return_num_tokens_scored:
            score_dict.update(dict(num_tokens_scored=float("nan")))
        if return_num_green_tokens:
            score_dict.update(dict(num_green_tokens=float("nan")))
        if return_green_fraction:
            score_dict.update(dict(green_fraction=float("nan")))
        if return_z_score:
            score_dict.update(dict(z_score=float("nan")))
        if return_p_value:
            score_dict.update(dict(p_value=float("nan")))
        if return_green_token_mask:
            score_dict.update(dict(green_token_mask=[]))
        if return_all_window_scores:
            score_dict.update(dict(window_list=[]))
        if return_z_at_T:
            score_dict.update(dict(z_score_at_T=torch.tensor([])))

        output_dict = {}
        if return_scores:
            output_dict.update(score_dict)
        # With return_prediction, resolve the threshold as detect() does and report a negative outcome.
        if return_prediction:
            # Compare against None so that an explicit threshold of 0.0 is honored.
            z_threshold = self.z_threshold if z_threshold is None else z_threshold
            assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
            output_dict["prediction"] = False

        return output_dict

    def _compute_z_score(self, observed_count, T):
        """One-proportion z-statistic for ``observed_count`` green tokens out of ``T`` scored tokens.

        Computes (observed_count - gamma * T) / sqrt(T * gamma * (1 - gamma)).
        """
        expected_count = self.gamma
        numer = observed_count - expected_count * T
        denom = sqrt(T * expected_count * (1 - expected_count))
        z = numer / denom
        return z

    def _compute_p_value(self, z):
        """Upper-tail p-value of ``z`` under the standard normal distribution."""
        p_value = scipy.stats.norm.sf(z)
        return p_value

    def _get_ngram_score_cached(self, prefix: tuple[int], target: int):
        """Return True if ``target`` is in the green list seeded by ``prefix``.

        Despite the name, no caching happens in this method; every call reseeds and re-samples.
        Any future cache must be invalidated when prf_type, context_width, self_salt or hash_key change.
        """
        greenlist_ids = self._get_greenlist_ids(torch.as_tensor(prefix, device=self.device))
        return bool(target in greenlist_ids)

    def _score_ngrams_in_passage(self, input_ids: torch.Tensor):
        """Collect every n-gram in ``input_ids`` and score whether its last token is green.

        Returns:
            (ngram_to_watermark_lookup, frequencies_table). The first maps each unique n-gram to a
            bool. The second is a ``collections.Counter`` of n-gram occurrences.

        Raises:
            ValueError: If fewer than ``context_width + 1`` tokens are given.
        """
        if len(input_ids) - self.context_width < 1:
            raise ValueError(
                f"Must have at least {1} token to score after "
                f"the first min_prefix_len={self.context_width} tokens required by the seeding scheme."
            )

        # With self-salting the scored token is part of its own seeding context, so n-grams are one shorter.
        token_ngram_generator = ngrams(input_ids.cpu().tolist(), self.context_width + 1 - self.self_salt)
        frequencies_table = collections.Counter(token_ngram_generator)
        ngram_to_watermark_lookup = {}
        for ngram_example in frequencies_table:
            prefix = ngram_example if self.self_salt else ngram_example[:-1]
            target = ngram_example[-1]
            ngram_to_watermark_lookup[ngram_example] = self._get_ngram_score_cached(prefix, target)

        return ngram_to_watermark_lookup, frequencies_table

    def _get_green_at_T_booleans(self, input_ids, ngram_to_watermark_lookup) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build per-token green/red flags, the flags with repeated n-grams dropped, and an offset map.

        Returns:
            (green_token_mask, green_token_mask_unique, offsets), where
            green_token_mask == green_token_mask_unique[offsets] everywhere except at positions
            that would be counted as repeats. ``offsets[i]`` is the index in the unique list of the
            last unique n-gram seen up to position ``i``.
        """
        green_token_mask, green_token_mask_unique, offsets = [], [], []
        used_ngrams = set()
        unique_ngram_idx = 0
        ngram_examples = ngrams(input_ids.cpu().tolist(), self.context_width + 1 - self.self_salt)

        for ngram_example in ngram_examples:
            is_green = ngram_to_watermark_lookup[ngram_example]
            green_token_mask.append(is_green)
            if not self.ignore_repeated_ngrams or ngram_example not in used_ngrams:
                used_ngrams.add(ngram_example)
                green_token_mask_unique.append(is_green)
                unique_ngram_idx += 1
            offsets.append(unique_ngram_idx - 1)
        return (
            torch.tensor(green_token_mask),
            torch.tensor(green_token_mask_unique),
            torch.tensor(offsets),
        )

    def _score_sequence(
        self,
        input_ids: torch.Tensor,
        return_num_tokens_scored: bool = True,
        return_num_green_tokens: bool = True,
        return_green_fraction: bool = True,
        return_green_token_mask: bool = False,
        return_z_score: bool = True,
        return_z_at_T: bool = True,
        return_p_value: bool = True,
    ):
        """Score the whole passage and return an HF-style dict with the requested statistics."""
        ngram_to_watermark_lookup, frequencies_table = self._score_ngrams_in_passage(input_ids)
        green_token_mask, green_unique, offsets = self._get_green_at_T_booleans(input_ids, ngram_to_watermark_lookup)

        # Sum the scores of all n-grams.
        if self.ignore_repeated_ngrams:
            # Count a green/red hit only once per unique n-gram. The total number of scored tokens
            # (T) becomes the number of unique n-grams: for each unique n-gram we compute the green
            # list induced by its context and check whether its last token falls in it.
            num_tokens_scored = len(frequencies_table.keys())
            green_token_count = sum(ngram_to_watermark_lookup.values())
        else:
            num_tokens_scored = sum(frequencies_table.values())
            assert num_tokens_scored == len(input_ids) - self.context_width + self.self_salt
            green_token_count = sum(freq * ngram_to_watermark_lookup[ngram] for ngram, freq in frequencies_table.items())
        assert green_token_count == green_unique.sum()

        # Computed unconditionally: the p-value and the z_at_T consistency check both need it,
        # even when the z-score itself is not requested.
        z_score = self._compute_z_score(green_token_count, num_tokens_scored)

        # HF-style dictionary.
        score_dict = dict()
        if return_num_tokens_scored:
            score_dict.update(dict(num_tokens_scored=num_tokens_scored))
        if return_num_green_tokens:
            score_dict.update(dict(num_green_tokens=green_token_count))
        if return_green_fraction:
            score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
        if return_z_score:
            score_dict.update(dict(z_score=z_score))
        if return_p_value:
            score_dict.update(dict(p_value=self._compute_p_value(z_score)))
        if return_green_token_mask:
            score_dict.update(dict(green_token_mask=green_token_mask.tolist()))
        if return_z_at_T:
            # z-score of each prefix of the (unique) hit sequence, mapped back to token positions.
            sizes = torch.arange(1, len(green_unique) + 1)
            seq_z_score_enum = torch.cumsum(green_unique, dim=0) - self.gamma * sizes
            seq_z_score_denom = torch.sqrt(sizes * self.gamma * (1 - self.gamma))
            z_score_at_effective_T = seq_z_score_enum / seq_z_score_denom
            z_score_at_T = z_score_at_effective_T[offsets]
            assert torch.isclose(z_score_at_T[-1], torch.tensor(z_score))

            score_dict.update(dict(z_score_at_T=z_score_at_T))

        return score_dict

    def _score_windows_impl_batched(
        self,
        input_ids: torch.Tensor,
        window_size: str,
        window_stride: int = 1,
    ):
        """Score every window of every requested size and return the best one.

        Args:
            input_ids: Token ids of the passage.
            window_size: Comma-separated window sizes (e.g. "10,20"), or "max" for all sizes
                from 1 to the effective length minus one.
            window_stride: Distance between consecutive window starts; ``None`` means 1.

        Returns:
            (optimal_z, optimal_window_size, z_score_max_per_window, cumulative_z_score, green_mask).
            Sizes that do not fit the passage get -inf in ``z_score_max_per_window``.

        Raises:
            ValueError: If ``window_stride`` < 1 or no requested size fits the passage.
        """
        # Implementation details:
        # 1) --ignore_repeated_ngrams is applied globally, and windowing is then applied to the
        #    reduced binary vector. This is only one way to implement it; another would be to
        #    ignore repeated n-grams within each window (probably harder to parallelize).
        # 2) Windows are taken over the binary vector of green/red hits, independent of
        #    context_width, unlike Kezhi's first implementation.
        # 3) Because of the windowing, the z-scores produced here cannot be converted directly to
        #    p-values and should only be used as labels for ROC plots calibrated to a chosen FPR.
        #    The overall score is inflated by the multiple hypothesis tests. The original notes
        #    naive_count_correction=True as a partial remedy; no such option exists in this method.
        ngram_to_watermark_lookup, _ = self._score_ngrams_in_passage(input_ids)
        green_mask, green_ids, offsets = self._get_green_at_T_booleans(input_ids, ngram_to_watermark_lookup)
        len_full_context = len(green_ids)

        partial_sum_id_table = torch.cumsum(green_ids, dim=0)

        if window_size == "max":
            # Could start later, since small windows cannot produce enough power:
            # solve (T * Spike_Entropy - g * T) / sqrt(T * g * (1 - g)) = z_thresh for T
            sizes = range(1, len_full_context)
        else:
            sizes = [int(x) for x in window_size.split(",") if len(x) > 0]

        stride = 1 if window_stride is None else window_stride
        if stride < 1:
            raise ValueError(f"window_stride must be a positive integer, got {window_stride}.")

        # -inf (not 0) so that a size that does not fit can never be selected as the optimum.
        z_score_max_per_window = torch.full((len(sizes),), float("-inf"))
        cumulative_eff_z_score = torch.zeros(len_full_context)
        window_fits = False
        for idx, size in enumerate(sizes):
            if size > len_full_context:
                continue
            # Hit count of every window of this size (one per start position), computed in parallel.
            window_score = torch.zeros(len_full_context - size + 1, dtype=torch.long)
            # The window starting at position 0...
            window_score[0] = partial_sum_id_table[size - 1]
            # ...and all windows starting at 1 onwards.
            window_score[1:] = partial_sum_id_table[size:] - partial_sum_id_table[:-size]
            # Keep only the window starts 0, stride, 2 * stride, ...
            window_score = window_score[::stride]

            # Batched z-scores.
            batched_z_score_enum = window_score - self.gamma * size
            z_score_denom = sqrt(size * self.gamma * (1 - self.gamma))
            batched_z_score = batched_z_score_enum / z_score_denom

            # Best window of this size.
            z_score_max_per_window[idx] = batched_z_score.max()

            # The window starting at s is credited to position s + size; positions past the end are dropped.
            z_score_at_effective_T = torch.cummax(batched_z_score, dim=0)[0]
            num_positions = len(range(size, len_full_context, stride))
            cumulative_eff_z_score[size::stride] = torch.maximum(
                cumulative_eff_z_score[size::stride], z_score_at_effective_T[:num_positions]
            )
            window_fits = True

        if not window_fits:
            raise ValueError(
                f"Could not find a fitting window with window sizes {window_size} for (effective) context length {len_full_context}."
            )

        # Best window size and its z-score.
        cumulative_z_score = cumulative_eff_z_score[offsets]
        optimal_z, optimal_window_size_idx = z_score_max_per_window.max(dim=0)
        optimal_window_size = sizes[int(optimal_window_size_idx)]
        return (
            optimal_z,
            optimal_window_size,
            z_score_max_per_window,
            cumulative_z_score,
            green_mask,
        )

    def _score_sequence_window(
        self,
        input_ids: torch.Tensor,
        return_num_tokens_scored: bool = True,
        return_num_green_tokens: bool = True,
        return_green_fraction: bool = True,
        return_green_token_mask: bool = False,
        return_z_score: bool = True,
        return_z_at_T: bool = True,
        return_p_value: bool = True,
        window_size: str = None,
        window_stride: int = 1,
    ):
        """Windowed counterpart of ``_score_sequence``: report statistics of the best-scoring window."""
        (
            optimal_z,
            optimal_window_size,
            _,
            z_score_at_T,
            green_mask,
        ) = self._score_windows_impl_batched(input_ids, window_size, window_stride)

        # HF-style dictionary.
        score_dict = dict()
        if return_num_tokens_scored:
            score_dict.update(dict(num_tokens_scored=optimal_window_size))

        # Recover the integer green count by inverting the z-score formula. Round rather than
        # truncate: float error can land just below the true integer.
        denom = sqrt(optimal_window_size * self.gamma * (1 - self.gamma))
        green_token_count = int(round(float(optimal_z * denom + self.gamma * optimal_window_size)))
        green_fraction = green_token_count / optimal_window_size
        if return_num_green_tokens:
            score_dict.update(dict(num_green_tokens=green_token_count))
        if return_green_fraction:
            score_dict.update(dict(green_fraction=green_fraction))
        if return_z_score:
            score_dict.update(dict(z_score=optimal_z))
        if return_z_at_T:
            score_dict.update(dict(z_score_at_T=z_score_at_T))
        if return_p_value:
            z_score = score_dict.get("z_score", optimal_z)
            score_dict.update(dict(p_value=self._compute_p_value(z_score)))

        # The per-token mask is the same as without windowing (scoring differs only by windowing).
        # TODO: mark the tokens that were actually counted differently.
        if return_green_token_mask:
            score_dict.update(dict(green_token_mask=green_mask.tolist()))

        return score_dict

    def detect(
        self,
        text: str = None,
        tokenized_text: list[int] = None,
        window_size: str = None,
        window_stride: int = None,
        return_prediction: bool = True,
        return_scores: bool = True,
        z_threshold: float = None,
        convert_to_float: bool = False,
        **kwargs,
    ) -> dict:
        """Score a text (raw or tokenized) for the watermark and return a result dictionary.

        Args:
            text: Raw text; it is normalized and then tokenized. Mutually exclusive with
                ``tokenized_text``.
            tokenized_text: Token ids (tensor or list of ints). Normalizers are not applied.
            window_size: If given, use windowed scoring (see ``_score_windows_impl_batched``).
            window_stride: Stride between windows for windowed scoring (``None`` means 1).
            return_prediction: Add a boolean ``prediction`` (and ``confidence`` when positive).
            return_scores: Include the score statistics in the output.
            z_threshold: Decision threshold; ``None`` uses ``self.z_threshold``.
            convert_to_float: Convert int values (including bools, which subclass int) to float.
            **kwargs: ``return_*`` flags forwarded to the scoring method.
        """
        assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
        if return_prediction:
            kwargs["return_p_value"] = True  # report the "confidence" := 1 - p of a positive detection

        # Optional normalizers only apply to raw text.
        if text is not None:
            for normalizer in self.normalizers:
                text = normalizer(text)
            if len(self.normalizers) > 0:
                print(f"Text after normalization:\n\n{text}\n")

        if tokenized_text is None:
            assert self.tokenizer is not None, (
                "Watermark detection on raw string "
                "requires an instance of the tokenizer "
                "that was used at generation time."
            )
            tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
            if len(tokenized_text) > 0 and tokenized_text[0] == self.tokenizer.bos_token_id:
                tokenized_text = tokenized_text[1:]
        else:
            # The scoring code calls tensor methods, so accept plain lists of ids as well.
            if not isinstance(tokenized_text, torch.Tensor):
                tokenized_text = torch.as_tensor(tokenized_text, dtype=torch.long)
            # Try to remove the BOS token from the beginning if it is there.
            if (self.tokenizer is not None) and len(tokenized_text) > 0 and (tokenized_text[0] == self.tokenizer.bos_token_id):
                tokenized_text = tokenized_text[1:]

        # Call the scoring method.
        if window_size is not None:
            score_dict = self._score_sequence_window(
                tokenized_text,
                window_size=window_size,
                window_stride=window_stride,
                **kwargs,
            )
        else:
            score_dict = self._score_sequence(tokenized_text, **kwargs)

        output_dict = {}
        if return_scores:
            output_dict.update(score_dict)

        # With return_prediction, run the hypothesis test and report the outcome.
        if return_prediction:
            # Compare against None so that an explicit threshold of 0.0 is honored.
            z_threshold = self.z_threshold if z_threshold is None else z_threshold
            assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
            output_dict["prediction"] = score_dict["z_score"] > z_threshold
            if output_dict["prediction"]:
                output_dict["confidence"] = 1 - score_dict["p_value"]

        # Convert numeric values to floats if requested.
        if convert_to_float:
            for key, value in output_dict.items():
                if isinstance(value, int):
                    output_dict[key] = float(value)

        return output_dict
def ngrams(sequence, n, pad_left=False, pad_right=False, pad_symbol=None):
    """Lazily produce the length-``n`` sliding windows over ``sequence`` as tuples.

    Args:
        sequence: Any iterable of items (e.g. token ids).
        n: Window length.
        pad_left: Prepend ``n - 1`` copies of ``pad_symbol`` before windowing.
        pad_right: Append ``n - 1`` copies of ``pad_symbol`` before windowing.
        pad_symbol: Filler item used for padding.

    Returns:
        An iterator of ``n``-tuples. It is empty when the (padded) input has fewer than ``n`` items.
    """
    stream = iter(sequence)
    if pad_left:
        stream = chain((pad_symbol,) * (n - 1), stream)
    if pad_right:
        stream = chain(stream, (pad_symbol,) * (n - 1))

    # n independent cursors over the same stream. Cursor k is advanced k steps, so zipping the
    # cursors pairs item i with items i+1, ..., i+n-1.
    cursors = tee(stream, n)
    for lag, cursor in enumerate(cursors):
        skipped = 0
        while skipped < lag:
            next(cursor, None)
            skipped += 1
    return zip(*cursors)