Spaces:
Running
Running
Commit
·
3b2f15b
1
Parent(s):
930671d
implemented detection
Browse files- wm_detector/core/detector.py +1 -4
- wm_detector/core/utils.py +1 -1
- wm_detector/web/app.py +18 -17
wm_detector/core/detector.py
CHANGED
@@ -6,12 +6,9 @@
|
|
6 |
|
7 |
import numpy as np
|
8 |
from scipy import special
|
9 |
-
from scipy.optimize import fminbound
|
10 |
|
11 |
import torch
|
12 |
-
from transformers import AutoTokenizer
|
13 |
-
|
14 |
-
import random
|
15 |
|
16 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
17 |
|
|
|
6 |
|
7 |
import numpy as np
|
8 |
from scipy import special
|
|
|
9 |
|
10 |
import torch
|
11 |
+
from transformers import AutoTokenizer
|
|
|
|
|
12 |
|
13 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
|
wm_detector/core/utils.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import random
|
2 |
import numpy as np
|
3 |
-
from .detector import
|
4 |
|
5 |
def generate_pastel_color():
|
6 |
"""Generate a pastel color in HSL format."""
|
|
|
1 |
import random
|
2 |
import numpy as np
|
3 |
+
from .detector import WmDetector
|
4 |
|
5 |
def generate_pastel_color():
|
6 |
"""Generate a pastel color in HSL format."""
|
wm_detector/web/app.py
CHANGED
@@ -3,9 +3,22 @@ Main Flask application for the watermark detection web interface.
|
|
3 |
"""
|
4 |
|
5 |
from flask import Flask, render_template, request, jsonify
|
6 |
-
from
|
|
|
|
|
7 |
from ..core.utils import get_token_details
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
def create_detector(detector_type, tokenizer, **kwargs):
|
10 |
"""Create a detector instance based on the specified type."""
|
11 |
detector_map = {
|
@@ -39,20 +52,11 @@ def create_app():
|
|
39 |
# Add zip to Jinja's global context
|
40 |
app.jinja_env.globals.update(zip=zip)
|
41 |
|
42 |
-
#
|
43 |
-
model_id = "meta-llama/Llama-3.2-1B-Instruct"
|
|
|
44 |
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="wm_detector/static/hf_cache")
|
45 |
-
|
46 |
-
def convert_nan_to_null(obj):
|
47 |
-
"""Convert NaN values to null for JSON serialization"""
|
48 |
-
import math
|
49 |
-
if isinstance(obj, float) and math.isnan(obj):
|
50 |
-
return None
|
51 |
-
elif isinstance(obj, dict):
|
52 |
-
return {k: convert_nan_to_null(v) for k, v in obj.items()}
|
53 |
-
elif isinstance(obj, list):
|
54 |
-
return [convert_nan_to_null(item) for item in obj]
|
55 |
-
return obj
|
56 |
|
57 |
@app.route("/", methods=["GET"])
|
58 |
def index():
|
@@ -68,9 +72,6 @@ def create_app():
|
|
68 |
text = data.get('text', '')
|
69 |
params = data.get('params', {})
|
70 |
|
71 |
-
if not isinstance(text, str):
|
72 |
-
return jsonify({'error': 'Text must be a string'}), 400
|
73 |
-
|
74 |
# Create a detector instance with the provided parameters
|
75 |
detector = create_detector(
|
76 |
detector_type=params.get('detector_type', 'maryland'),
|
|
|
3 |
"""
|
4 |
|
5 |
from flask import Flask, render_template, request, jsonify
|
6 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
7 |
+
|
8 |
+
from ..core.detector import MarylandDetector, MarylandDetectorZ, OpenaiDetector, OpenaiDetectorZ
|
9 |
from ..core.utils import get_token_details
|
10 |
|
11 |
+
def convert_nan_to_null(obj):
    """Recursively replace NaN floats with ``None`` for JSON serialization.

    JSON has no NaN literal, so a NaN that reaches the serializer is emitted
    as the bare token ``NaN`` and strict parsers reject the payload.  This
    helper walks dicts, lists and tuples and swaps every NaN float for
    ``None`` (serialized as ``null``).

    Args:
        obj: Any value; containers are traversed, everything else is
            returned unchanged.

    Returns:
        A new dict/list/tuple with NaN floats replaced by ``None``, or
        ``obj`` itself when it is not a NaN float or a supported container.
        Tuple subclasses come back as plain tuples.
    """
    import math  # local import keeps the helper self-contained

    if isinstance(obj, float):
        return None if math.isnan(obj) else obj
    if isinstance(obj, dict):
        return {key: convert_nan_to_null(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_nan_to_null(item) for item in obj]
    if isinstance(obj, tuple):
        # Tuples are serialized as JSON arrays too, so a NaN hidden inside
        # one would otherwise slip through untouched.
        return tuple(convert_nan_to_null(item) for item in obj)
    return obj
|
21 |
+
|
22 |
def create_detector(detector_type, tokenizer, **kwargs):
|
23 |
"""Create a detector instance based on the specified type."""
|
24 |
detector_map = {
|
|
|
52 |
# Add zip to Jinja's global context
|
53 |
app.jinja_env.globals.update(zip=zip)
|
54 |
|
55 |
+
# Pick a model
|
56 |
+
# model_id = "meta-llama/Llama-3.2-1B-Instruct"
|
57 |
+
model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
|
58 |
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="wm_detector/static/hf_cache")
|
59 |
+
model = AutoModelForCausalLM.from_pretrained(model_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
@app.route("/", methods=["GET"])
|
62 |
def index():
|
|
|
72 |
text = data.get('text', '')
|
73 |
params = data.get('params', {})
|
74 |
|
|
|
|
|
|
|
75 |
# Create a detector instance with the provided parameters
|
76 |
detector = create_detector(
|
77 |
detector_type=params.get('detector_type', 'maryland'),
|