izhx commited on
Commit
8bf2130
1 Parent(s): 29b0301
Files changed (2) hide show
  1. app.py +231 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Hello-SimpleAI Org. 2023.
2
+ # Licensed under the Apache License, Version 2.0.
3
+
4
+ import os
5
+ import pickle
6
+ import re
7
+ from typing import Callable, List, Tuple
8
+
9
+ import gradio as gr
10
+ from nltk.data import load as nltk_load
11
+ import numpy as np
12
+ from sklearn.linear_model import LogisticRegression
13
+ import torch
14
+ from transformers.utils import cached_file
15
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
16
+
17
+
18
+ AUTH_TOKEN = os.environ.get("access_token")
19
+ DET_LING_ID = 'Hello-SimpleAI/chatgpt-detector-ling'
20
+
21
+
22
+ def download_file(filename):
23
+ return cached_file(DET_LING_ID, filename, use_auth_token=AUTH_TOKEN)
24
+
25
+
26
+ NLTK = nltk_load(download_file('english.pickle'))
27
+ sent_cut_en = NLTK.tokenize
28
+ LR_GLTR_EN, LR_PPL_EN = [
29
+ pickle.load(open(download_file(f'{lang}-gpt2-{name}.pkl'), 'rb'))
30
+ for lang, name in [('en', 'gltr'), ('en', 'ppl')]
31
+ ]
32
+
33
+ NAME_EN = 'gpt2'
34
+ TOKENIZER_EN = GPT2Tokenizer.from_pretrained(NAME_EN)
35
+ MODEL_EN = GPT2LMHeadModel.from_pretrained(NAME_EN)
36
+
37
+
38
+ # code borrowed from https://github.com/blmoistawinde/HarvestText
39
+ def sent_cut_zh(para: str) -> List[str]:
40
+ para = re.sub('([。!?\?!])([^”’)\])】])', r"\1\n\2", para) # 单字符断句符
41
+ para = re.sub('(\.{3,})([^”’)\])】….])', r"\1\n\2", para) # 英文省略号
42
+ para = re.sub('(\…+)([^”’)\])】….])', r"\1\n\2", para) # 中文省略号
43
+ para = re.sub('([。!?\?!]|\.{3,}|\…+)([”’)\])】])([^,。!?\?….])', r'\1\2\n\3', para)
44
+ # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号
45
+ para = para.rstrip() # 段尾如果有多余的\n就去掉它
46
+ # 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。
47
+ sentences = para.split("\n")
48
+ sentences = [sent.strip() for sent in sentences]
49
+ sentences = [sent for sent in sentences if len(sent.strip()) > 0]
50
+ return sentences
51
+
52
+
53
+ CROSS_ENTROPY = torch.nn.CrossEntropyLoss(reduction='none')
54
+
55
+
56
+ def gpt2_features(
57
+ text: str, tokenizer: GPT2Tokenizer, model: GPT2LMHeadModel, sent_cut: Callable
58
+ ) -> Tuple[List[int], List[float]]:
59
+ # Tokenize
60
+ input_max_length = tokenizer.model_max_length - 2
61
+ token_ids, offsets = list(), list()
62
+ sentences = sent_cut(text)
63
+ for s in sentences:
64
+ tokens = tokenizer.tokenize(s)
65
+ ids = tokenizer.convert_tokens_to_ids(tokens)
66
+ difference = len(token_ids) + len(ids) - input_max_length
67
+ if difference > 0:
68
+ ids = ids[:-difference]
69
+ offsets.append((len(token_ids), len(token_ids) + len(ids))) # 左开右闭
70
+ token_ids.extend(ids)
71
+ if difference >= 0:
72
+ break
73
+
74
+ input_ids = torch.tensor([tokenizer.bos_token_id] + token_ids)
75
+ logits = model(input_ids).logits
76
+ # Shift so that n-1 predict n
77
+ shift_logits = logits[:-1].contiguous()
78
+ shift_target = input_ids[1:].contiguous()
79
+ loss = CROSS_ENTROPY(shift_logits, shift_target)
80
+
81
+ all_probs = torch.softmax(shift_logits, dim=-1)
82
+ sorted_ids = torch.argsort(all_probs, dim=-1, descending=True) # stable=True
83
+ expanded_tokens = shift_target.unsqueeze(-1).expand_as(sorted_ids)
84
+ indices = torch.where(sorted_ids == expanded_tokens)
85
+ rank = indices[-1]
86
+ counter = [
87
+ rank < 10,
88
+ (rank >= 10) & (rank < 100),
89
+ (rank >= 100) & (rank < 1000),
90
+ rank >= 1000
91
+ ]
92
+ counter = [c.long().sum(-1).item() for c in counter]
93
+
94
+
95
+ # compute different-level ppl
96
+ text_ppl = loss.mean().exp().item()
97
+ sent_ppl = list()
98
+ for start, end in offsets:
99
+ nll = loss[start: end].sum() / (end - start)
100
+ sent_ppl.append(nll.exp().item())
101
+ max_sent_ppl = max(sent_ppl)
102
+ sent_ppl_avg = sum(sent_ppl) / len(sent_ppl)
103
+ if len(sent_ppl) > 1:
104
+ sent_ppl_std = torch.std(torch.tensor(sent_ppl)).item()
105
+ else:
106
+ sent_ppl_std = 0
107
+
108
+ mask = torch.tensor([1] * loss.size(0))
109
+ step_ppl = loss.cumsum(dim=-1).div(mask.cumsum(dim=-1)).exp()
110
+ max_step_ppl = step_ppl.max(dim=-1)[0].item()
111
+ step_ppl_avg = step_ppl.sum(dim=-1).div(loss.size(0)).item()
112
+ if step_ppl.size(0) > 1:
113
+ step_ppl_std = step_ppl.std().item()
114
+ else:
115
+ step_ppl_std = 0
116
+ ppls = [
117
+ text_ppl, max_sent_ppl, sent_ppl_avg, sent_ppl_std,
118
+ max_step_ppl, step_ppl_avg, step_ppl_std
119
+ ]
120
+ return counter, ppls # type: ignore
121
+
122
+
123
+ def lr_predict(
124
+ f_gltr: List[int], f_ppl: List[float], lr_gltr: LogisticRegression, lr_ppl: LogisticRegression,
125
+ id_to_label: List[str]
126
+ ) -> List:
127
+ x_gltr = np.asarray([f_gltr])
128
+ gltr_label = lr_gltr.predict(x_gltr)[0]
129
+ gltr_prob = lr_gltr.predict_proba(x_gltr)[0, gltr_label]
130
+ x_ppl = np.asarray([f_ppl])
131
+ ppl_label = lr_ppl.predict(x_ppl)[0]
132
+ ppl_prob = lr_ppl.predict_proba(x_ppl)[0, ppl_label]
133
+ return [id_to_label[gltr_label], gltr_prob, id_to_label[ppl_label], ppl_prob]
134
+
135
+
136
+ def predict_en(text: str) -> List:
137
+ with torch.no_grad():
138
+ feat = gpt2_features(text, TOKENIZER_EN, MODEL_EN, sent_cut_en)
139
+ out = lr_predict(*feat, LR_GLTR_EN, LR_PPL_EN, ['Human', 'ChatGPT'])
140
+ return out
141
+
142
+
143
+ def predict_zh(text: str) -> List:
144
+ with torch.no_grad():
145
+ feat = gpt2_features(text, TOKENIZER_ZH, MODEL_ZH, sent_cut_zh)
146
+ out = lr_predict(*feat, None, None, ['人类', 'ChatGPT'])
147
+ return out
148
+
149
+
150
+ with gr.Blocks() as demo:
151
+ gr.Markdown(
152
+ """
153
+ ## ChatGPT Detector 🔬 (Linguistic version)
154
+ Visit our project on Github: [chatgpt-comparison-detection project](https://github.com/Hello-SimpleAI/chatgpt-comparison-detection)<br>
155
+ 欢迎在 Github 上关注我们的 [ChatGPT 对比与检测项目](https://github.com/Hello-SimpleAI/chatgpt-comparison-detection)
156
+ We provide three kinds of detectors, all in Bilingual / 我们提供了三个版本的检测器,且都支持中英文:
157
+ - [QA version / 问答版](https://huggingface.co/spaces/Hello-SimpleAI/chatgpt-detector-qa)<br>
158
+ detect whether an **answer** is generated by ChatGPT for certain **question**, using PLM-based classifiers / 判断某个**问题的回答**是否由ChatGPT生成,使用基于PTM的分类器来开发;
159
+ - [Sinlge-text version / 独立文本版](https://huggingface.co/spaces/Hello-SimpleAI/chatgpt-detector-single)<br>
160
+ detect whether a piece of text is ChatGPT generated, using PLM-based classifiers / 判断**单条文本**是否由ChatGPT生成,使用基于PTM的分类器来开发;
161
+ - [**Linguistic version / 语言学版** (👈 Current / 当前使用)](https://huggingface.co/spaces/Hello-SimpleAI/chatgpt-detector-ling)<br>
162
+ detect whether a piece of text is ChatGPT generated, using linguistic features / 判断**单条文本**是否由ChatGPT生成,使用基于语言学特征的模型来开发;
163
+
164
+ ## Introduction:
165
+ Two Logistic regression models trained with two kinds of features:
166
+ 1. [GLTR](https://aclanthology.org/P19-3019) Test-2, Language model predict token rank top-k buckets, top 10, 10-100, 100-1000, 1000+.
167
+ 2. PPL-based, text ppl, `avg` & `max` & `std` of sentence ppls, `avg` & `max` &`std` of timestep ppls.
168
+
169
+ English LM is [GPT2-small](https://huggingface.co/gpt2).
170
+
171
+ ## 介绍:
172
+ 两个逻辑回归模型, 分别使用以下两种特征:
173
+ 1. [GLTR](https://aclanthology.org/P19-3019) Test-2, 每个词的语言模型预测排名分桶, top 10, 10-100, 100-1000, 1000+.
174
+ 2. 基于语言模型困惑度 (PPL), text ppl, `avg` & `max` & `std` of sentence ppls, `avg` & `max` &`std` of timestep ppls.
175
+
176
+ 中文语言模型使用 闻仲 [Wenzhong-GPT2-110M](https://huggingface.co/IDEA-CCNL/Wenzhong-GPT2-110M).
177
+
178
+ """
179
+ )
180
+
181
+ with gr.Tab("English"):
182
+ gr.Markdown(
183
+ """
184
+ Note: Providing more text to the `Text` box can make the prediction more accurate!
185
+ """
186
+ )
187
+ a1 = gr.Textbox(lines=5, label='Text', value="""
188
+ There are a few things that can help protect your credit card information from being misused when you give it to a restaurant or any other business:
189
+ \nEncryption: Many businesses use encryption to protect your credit card information when it is being transmitted or stored.
190
+ This means that the information is transformed into a code that is difficult for anyone to read without the right key.
191
+ """
192
+ )
193
+ button1 = gr.Button("🤖 Predict!")
194
+ label1_gltr = gr.Textbox(lines=1, label='GLTR Predicted Label 🎃')
195
+ score1_gltr = gr.Textbox(lines=1, label='GLTR Probability')
196
+ label1_ppl = gr.Textbox(lines=1, label='PPL Predicted Label 🎃')
197
+ score1_ppl = gr.Textbox(lines=1, label='PPL Probability')
198
+
199
+ with gr.Tab("中文版"):
200
+ gr.Markdown(
201
+ """
202
+ 注意: 在`文本`栏中输入更多的文本,可以让预测更准确哦!
203
+ """
204
+ )
205
+ a2 = gr.Textbox(lines=5, label='文本',value="""
206
+ 对于OpenAI大力出奇迹的工作,自然每个人都有自己的看点。
207
+ 我自己最欣赏的地方是ChatGPT如何解决 “AI校正(Alignment)“这个问题。
208
+ 这个问题也是我们课题组这两年在探索的学术问题之一。
209
+ """
210
+ )
211
+ button2 = gr.Button("🤖 预测!")
212
+ label2_gltr = gr.Textbox(lines=1, label='GLTR 预测结果 🎃')
213
+ score2_gltr = gr.Textbox(lines=1, label='GLTR 模型概率')
214
+ label2_ppl = gr.Textbox(lines=1, label='PPL 预测结果 🎃')
215
+ score2_ppl = gr.Textbox(lines=1, label='PPL 模型概率')
216
+
217
+ button1.click(predict_en, inputs=[a1], outputs=[label1_gltr, score1_gltr, label1_ppl, score1_ppl])
218
+ button2.click(predict_zh, inputs=[a2], outputs=[label2_gltr, score2_gltr, label2_ppl, score2_ppl])
219
+
220
+ # Page Count
221
+ gr.Markdown(
222
+ """
223
+ <center>
224
+ <a href='https://clustrmaps.com/site/1bsdd' title='Visit tracker'>
225
+ < img src='//clustrmaps.com/map_v2.png?cl=080808&w=a&t=tt&d=NvxUHBTxY0ECXEuebgz8Ym8ynpVtduq59ENXoQpFh74&co=ffffff&ct=808080'/>
226
+ </a>
227
+ </center>
228
+ """
229
+ )
230
+
231
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers>=4.20,<4.26.0
2
+ nltk>=3.0,<=4.0
3
+ scikit-learn>=1.0,<=1.2