""" Main Flask application for the watermark detection web interface. """ from flask import Flask, render_template, request, jsonify, Response, stream_with_context from transformers import AutoModelForCausalLM, AutoTokenizer import torch import json from ..core.detector import MarylandDetector, MarylandDetectorZ, OpenaiDetector, OpenaiDetectorZ from ..core.generator import WmGenerator, OpenaiGenerator, MarylandGenerator from .utils import get_token_details, template_prompt CACHE_DIR = "wm_interactive/static/hf_cache" def convert_nan_to_null(obj): """Convert NaN values to null for JSON serialization""" import math if isinstance(obj, float) and math.isnan(obj): return None elif isinstance(obj, dict): return {k: convert_nan_to_null(v) for k, v in obj.items()} elif isinstance(obj, list): return [convert_nan_to_null(item) for item in obj] return obj def set_to_int(value, default_value = None): try: return int(value) except (ValueError, TypeError): return default_value def create_detector(detector_type, tokenizer, **kwargs): """Create a detector instance based on the specified type.""" detector_map = { 'maryland': MarylandDetector, 'marylandz': MarylandDetectorZ, 'openai': OpenaiDetector, 'openaiz': OpenaiDetectorZ } # Validate and set default values for parameters if 'seed' in kwargs: kwargs['seed'] = set_to_int(kwargs['seed'], default_value = 0) if 'ngram' in kwargs: kwargs['ngram'] = set_to_int(kwargs['ngram'], default_value = 1) detector_class = detector_map.get(detector_type, MarylandDetector) return detector_class(tokenizer=tokenizer, **kwargs) def create_app(): app = Flask(__name__, static_folder='../static', template_folder='../templates') # Add zip to Jinja's global context app.jinja_env.globals.update(zip=zip) # Pick a model # model_id = "meta-llama/Llama-3.2-1B-Instruct" model_id = "HuggingFaceTB/SmolLM2-135M-Instruct" tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR) model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=CACHE_DIR).to("cuda" if torch.cuda.is_available() else "cpu") # Create default generator generator = MarylandGenerator(model, tokenizer, ngram=1, seed=0) @app.route("/", methods=["GET"]) def index(): return render_template("index.html") @app.route("/tokenize", methods=["POST"]) def tokenize(): try: data = request.get_json() if not data: return jsonify({'error': 'No JSON data received'}), 400 text = data.get('text', '') params = data.get('params', {}) # Create a detector instance with the provided parameters detector = create_detector( detector_type=params.get('detector_type', 'maryland'), tokenizer=tokenizer, seed=params.get('seed', 0), ngram=params.get('ngram', 1) ) if text: try: display_info = get_token_details(text, detector) # Extract summary stats (last item in display_info) stats = display_info.pop() response_data = { 'token_count': len(display_info), 'tokens': [info['token'] for info in display_info], 'colors': [info['color'] for info in display_info], 'scores': [info['score'] if info.get('is_scored', False) else None for info in display_info], 'pvalues': [info['pvalue'] if info.get('is_scored', False) else None for info in display_info], 'final_score': stats.get('final_score', 0) if stats.get('final_score') is not None else 0, 'ntoks_scored': stats.get('ntoks_scored', 0) if stats.get('ntoks_scored') is not None else 0, 'final_pvalue': stats.get('final_pvalue', 0.5) if stats.get('final_pvalue') is not None else 0.5 } # Convert any NaN values to null before sending response_data = convert_nan_to_null(response_data) # Ensure 
numeric fields have default values if they became null if response_data['final_score'] is None: response_data['final_score'] = 0 if response_data['ntoks_scored'] is None: response_data['ntoks_scored'] = 0 if response_data['final_pvalue'] is None: response_data['final_pvalue'] = 0.5 return jsonify(response_data) except Exception as e: app.logger.error(f'Error processing text: {str(e)}') return jsonify({'error': f'Error processing text: {str(e)}'}), 500 return jsonify({ 'token_count': 0, 'tokens': [], 'colors': [], 'scores': [], 'pvalues': [], 'final_score': 0, 'ntoks_scored': 0, 'final_pvalue': 0.5 }) except Exception as e: app.logger.error(f'Server error: {str(e)}') return jsonify({'error': f'Server error: {str(e)}'}), 500 @app.route("/generate", methods=["POST"]) def generate(): try: data = request.get_json() if not data: return jsonify({'error': 'No JSON data received'}), 400 prompt = template_prompt(data.get('prompt', '')) params = data.get('params', {}) temperature = float(params.get('temperature', 0.8)) def generate_stream(): try: # Create generator with correct parameters generator_class = OpenaiGenerator if params.get('detector_type') == 'openai' else MarylandGenerator generator = generator_class( model=model, tokenizer=tokenizer, ngram=set_to_int(params.get('ngram', 1)), seed=set_to_int(params.get('seed', 0)), delta=float(params.get('delta', 2.0)), ) # Get special tokens to filter out special_tokens = { '<|im_start|>', '<|im_end|>', tokenizer.pad_token, tokenizer.eos_token, tokenizer.bos_token if hasattr(tokenizer, 'bos_token') else None, tokenizer.sep_token if hasattr(tokenizer, 'sep_token') else None } special_tokens = {t for t in special_tokens if t is not None} # Encode prompt prompt_tokens = tokenizer.encode(prompt) prompt_size = len(prompt_tokens) max_gen_len = 100 total_len = min(getattr(model.config, 'max_position_embeddings', 2048), max_gen_len + prompt_size) # Initialize generation tokens = torch.full((1, total_len), model.config.pad_token_id).to(model.device).long() tokens[0, :prompt_size] = torch.tensor(prompt_tokens).long() input_text_mask = tokens != model.config.pad_token_id # Generate token by token prev_pos = 0 outputs = None # Initialize outputs to None for cur_pos in range(prompt_size, total_len): # Get model outputs outputs = model.forward( tokens[:, prev_pos:cur_pos], use_cache=True, past_key_values=outputs.past_key_values if prev_pos > 0 else None ) # Sample next token using the generator's sampling method ngram_tokens = tokens[0, cur_pos-generator.ngram:cur_pos].tolist() aux = { 'ngram_tokens': ngram_tokens, 'cur_pos': cur_pos, } next_token = generator.sample_next( outputs.logits[:, -1, :], aux, temperature=temperature, top_p=0.9 ) # Check for EOS token if next_token == model.config.eos_token_id: break # Decode and check if it's a special token new_text = tokenizer.decode([next_token]) if new_text not in special_tokens and not any(st in new_text for st in special_tokens): yield f"data: {json.dumps({'token': new_text, 'done': False})}\n\n" # Update token and position tokens[0, cur_pos] = next_token prev_pos = cur_pos # Send final complete text, filtering out special tokens final_tokens = tokens[0, prompt_size:cur_pos+1].tolist() final_text = tokenizer.decode(final_tokens) for st in special_tokens: final_text = final_text.replace(st, '') yield f"data: {json.dumps({'text': final_text, 'done': True})}\n\n" except Exception as e: app.logger.error(f'Error generating text: {str(e)}') yield f"data: {json.dumps({'error': str(e)})}\n\n" return 
Response(stream_with_context(generate_stream()), mimetype='text/event-stream') except Exception as e: app.logger.error(f'Server error: {str(e)}') return jsonify({'error': f'Server error: {str(e)}'}), 500 return app app = create_app() if __name__ == "__main__": app.run(host='0.0.0.0', port=7860)
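
# Example client usage (a minimal sketch, not part of the app): assumes the server
# is running locally on port 7860 and that the `requests` library is installed.
#
#   import json
#   import requests
#
#   # /tokenize returns per-token scores and summary statistics as JSON
#   resp = requests.post(
#       "http://localhost:7860/tokenize",
#       json={"text": "Hello world", "params": {"detector_type": "maryland", "seed": 0, "ngram": 1}},
#   )
#   print(resp.json()["final_pvalue"])
#
#   # /generate streams Server-Sent Events; each event line looks like "data: {...}"
#   with requests.post(
#       "http://localhost:7860/generate",
#       json={"prompt": "Write a haiku", "params": {"detector_type": "maryland", "temperature": 0.8}},
#       stream=True,
#   ) as resp:
#       for line in resp.iter_lines():
#           if line.startswith(b"data: "):
#               print(json.loads(line[len(b"data: "):]))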