jer233 commited on
Commit
98bdd9f
·
verified ·
1 Parent(s): 515bde6

Update demo/demo.py

Browse files
Files changed (1) hide show
  1. demo/demo.py +84 -69
demo/demo.py CHANGED
@@ -1,86 +1,101 @@
1
  import gradio as gr
 
 
 
 
2
 
3
- from MMD_calculate import MMDMPDetector
4
 
5
- detector = MMDMPDetector() # Initialize your MMD-MP detector
6
- MINIMUM_TOKENS = 64 # Minimum number of tokens for detection
7
 
8
- def count_tokens(text):
9
- return len(text.split()) # Count the number of tokens (words) in the text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- def run_detector(input_text):
12
- # Check if input meets the token requirement
13
- if count_tokens(input_text) < MINIMUM_TOKENS:
14
- return f"Error: Text is too short! At least {MINIMUM_TOKENS} tokens are required."
15
-
16
- # Perform detection (replace this with your model's prediction logic)
17
- prediction = detector.predict(input_text)
18
- return f"Result: {prediction}"
19
 
20
- def change_mode(mode):
21
- if mode == "Low False Positive Rate":
22
- detector.set_mode("low-fpr") # Adjust detector mode
23
- elif mode == "High Accuracy":
24
- detector.set_mode("accuracy")
25
- return f"Mode set to: {mode}"
26
 
27
  css = """
28
- .green { color: black!important; line-height:1.9em; padding: 0.2em 0.2em; background: #ccffcc; border-radius:0.5rem;}
29
- .red { color: black!important; line-height:1.9em; padding: 0.2em 0.2em; background: #ffad99; border-radius:0.5rem;}
30
- .hyperlinks {
31
- display: flex;
32
- align-items: center;
33
- justify-content: flex-end;
34
- padding: 12px;
35
- margin: 0 10px;
36
- text-decoration: none;
37
- color: #000;
38
- }
39
  """
40
 
41
- with gr.Blocks(css=css, theme=gr.themes.Default(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"])) as app:
42
- # Header Row
43
  with gr.Row():
44
- with gr.Column(scale=3):
45
- gr.HTML("<h1>Binoculars: Zero-Shot LLM-Text Detector</h1>")
46
- with gr.Column(scale=1):
47
- gr.HTML("""
48
- <p class="hyperlinks">
49
- <a href="https://arxiv.org/abs/2401.12070" target="_blank">Paper</a> |
50
- <a href="https://github.com/AHans30/Binoculars" target="_blank">Code</a> |
51
- <a href="mailto:contact@example.com" target="_blank">Contact</a>
52
- </p>
53
- """)
54
-
55
- # Input Section
56
  with gr.Row():
57
- input_text = gr.Textbox(placeholder="Enter text here...", lines=8, label="Input Text")
58
-
59
- # Mode Selector and Buttons
 
 
60
  with gr.Row():
61
- mode_selector = gr.Dropdown(
62
- choices=["Low False Positive Rate", "High Accuracy"],
63
- label="Detection Mode",
64
- value="Low False Positive Rate"
65
  )
66
- submit_button = gr.Button("Run Binoculars", variant="primary")
67
- clear_button = gr.Button("Clear")
68
-
69
- # Output Section
70
  with gr.Row():
71
- output_text = gr.Textbox(label="Prediction", value="Results will appear here...")
72
-
73
- # Disclaimer Section
 
 
 
 
 
 
 
 
 
 
 
74
  with gr.Accordion("Disclaimer", open=False):
75
- gr.Markdown("""
76
- - **Accuracy**: This detector uses state-of-the-art techniques, but no model is perfect.
77
- - **Mode Information**:
78
- - High Accuracy: Maximizes accuracy by adjusting thresholds.
79
- - Low False Positive Rate: Reduces human-written text being falsely flagged as AI-generated.
80
- - **Limitations**: Detection is best on texts with 64–300 tokens. Very short or extremely long texts may lead to inaccurate results.
81
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- # Bind Functions to Buttons
84
- submit_button.click(run_detector, inputs=input_text, outputs=output_text)
85
- clear_button.click(lambda: ("", ""), outputs=[input_text, output_text])
86
- mode_selector.change(change_mode, inputs=mode_selector, outputs=mode_selector)
 
1
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel

# BUG fix: these helpers are called by run_test_power but their imports were
# commented out, which made the app raise NameError on every detection run.
from MMD_calculate import mmd_two_sample_baseline  # Adjust path based on your structure
from utils_MMD import extract_features  # Example helper from your utils

# Minimum number of tokens each text must have for the MMD test to be meaningful.
MINIMUM_TOKENS = 64
8
 
9
def count_tokens(text, tokenizer):
    """Return how many token ids *tokenizer* produces for *text*."""
    encoding = tokenizer(text)
    return len(encoding.input_ids)
11
 
12
def run_test_power(model_name, tokenizer_name, real_text, generated_text, N, threshold=0.5):
    """Classify *generated_text* against *real_text* with an MMD two-sample test.

    Loads the tokenizer named *tokenizer_name* and the feature model named
    *model_name*, extracts features for both texts, computes the test-power
    list over ``N`` repetitions, and returns a "Human"/"AI" verdict based on
    the mean test power.

    Parameters:
        model_name: Hugging Face model id used for feature extraction.
        tokenizer_name: Hugging Face tokenizer id used for token counting.
        real_text: Reference (human-written) text sample.
        generated_text: Text sample to classify.
        N: Number of test repetitions to average over.
        threshold: Decision boundary on the mean test power; below it the
            text is labelled "Human".  (Was an undefined name in the
            original — NameError at runtime.)

    Returns:
        A human-readable result string, or an error message when either
        text is shorter than MINIMUM_TOKENS tokens.
    """
    # BUG fixes: the original called .cuda() on the *tokenizer* (tokenizers are
    # plain CPU objects with no .cuda()), loaded it from model_name instead of
    # tokenizer_name, and loaded the model from the undefined name `model`.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model = AutoModel.from_pretrained(model_name)
    if torch.cuda.is_available():
        model = model.cuda()

    # Both texts must be long enough for a meaningful test.
    if (count_tokens(real_text, tokenizer) < MINIMUM_TOKENS
            or count_tokens(generated_text, tokenizer) < MINIMUM_TOKENS):
        return f"Too short length. Need minimum {MINIMUM_TOKENS} tokens to calculate Test Power."

    # Extract features for each text.
    # NOTE(review): extract_features takes the model/tokenizer *names*, so the
    # objects loaded above are not passed down — presumably it re-loads them
    # internally; confirm against utils_MMD.
    fea_real_ls = extract_features(model_name, tokenizer_name, [real_text])
    fea_generated_ls = extract_features(model_name, tokenizer_name, [generated_text])

    # BUG fix: N was accepted but ignored (hard-coded N=10).
    test_power_ls = mmd_two_sample_baseline(fea_real_ls, fea_generated_ls, N=N)

    # Average test power over the N repetitions.
    power_test_value = sum(test_power_ls) / len(test_power_ls)

    # Low test power => the two samples are statistically indistinguishable,
    # i.e. the text looks human-written.
    if power_test_value < threshold:
        return "Prediction: Human"
    else:
        return "Prediction: AI"
39
 
 
 
 
 
 
 
 
 
40
 
 
 
 
 
 
 
41
 
42
css = """
#header { text-align: center; font-size: 1.5em; margin-bottom: 20px; }
#output-text { font-weight: bold; font-size: 1.2em; }
"""

# Gradio App
with gr.Blocks(css=css) as app:
    with gr.Row():
        gr.HTML('<div id="header">Human or AI Text Detector</div>')
    with gr.Row():
        gr.Markdown(
            """
            [Paper](https://openreview.net/forum?id=z9j7wctoGV) | [Code](https://github.com/xLearn-AU/R-Detect) | [Contact](mailto:1730421718@qq.com)
            """
        )
    with gr.Row():
        input_text = gr.Textbox(
            label="Input Text",
            placeholder="Enter the text to check",
            lines=8,
        )
    with gr.Row():
        model_name = gr.Dropdown(
            ["gpt2-medium", "gpt2-large", "t5-large", "t5-small", "roberta-base", "roberta-base-openai-detector", "falcon-rw-1b"],
            label="Select Model",
            value="gpt2-medium",
        )
    with gr.Row():
        submit_button = gr.Button("Run Detection", variant="primary")
        clear_button = gr.Button("Clear", variant="secondary")
    with gr.Row():
        # PEP 8 fix: no spaces around '=' in keyword arguments.
        output = gr.Textbox(
            label="Prediction",
            placeholder="Prediction: Human or AI",
            elem_id="output-text",
        )
    with gr.Accordion("Disclaimer", open=False):
        gr.Markdown(
            """
            - **Disclaimer**: This tool is for demonstration purposes only. It is not a foolproof AI detector.
            - **Accuracy**: Results may vary based on input length and quality.
            """
        )
    with gr.Accordion("Citations", open=False):
        gr.Markdown(
            """
            ```
            @inproceedings{zhangs2024MMDMP,
                title={Detecting Machine-Generated Texts by Multi-Population Aware Optimization for Maximum Mean Discrepancy},
                author={Zhang, Shuhai and Song, Yiliao and Yang, Jiahao and Li, Yuanqing and Han, Bo and Tan, Mingkui},
                booktitle = {International Conference on Learning Representations (ICLR)},
                year={2024}
            }
            ```
            """
        )

    # BUG fix: the original bound the undefined name `detect_text`, so every
    # click raised NameError.  Route the click through run_test_power, using
    # the selected model id as both the feature model and the tokenizer.
    # NOTE(review): the UI collects only one text while run_test_power expects
    # a real/generated pair — the same text is passed as both samples here;
    # confirm this single-text usage is the intended demo behavior.
    submit_button.click(
        lambda text, model: run_test_power(model, model, text, text, 10),
        inputs=[input_text, model_name],
        outputs=output,
    )
    clear_button.click(lambda: ("", ""), inputs=[], outputs=[input_text, output])

app.launch()