DPSGDTool / app /routes.py
Shuya Feng
update
b0b2c21
raw
history blame
5.58 kB
from flask import Blueprint, render_template, jsonify, request, current_app
from app.training.mock_trainer import MockTrainer
from app.training.privacy_calculator import PrivacyCalculator
from flask_cors import cross_origin
import os
# Try to import a real trainer, falling back to MockTrainer if dependencies
# aren't available. The simplified real trainer is tried first, then the full
# one; REAL_TRAINER_AVAILABLE records whether either import succeeded.
try:
    from app.training.simplified_real_trainer import SimplifiedRealTrainer as RealTrainer
    REAL_TRAINER_AVAILABLE = True
    print("Simplified real trainer available - will use MNIST dataset")
except ImportError as e:
    # It is the *simplified* trainer that failed here; the next attempt is
    # the full RealTrainer, so the message must say so.
    print(f"Simplified real trainer not available ({e}) - trying full version")
    try:
        from app.training.real_trainer import RealTrainer
        REAL_TRAINER_AVAILABLE = True
        print("Full real trainer available - will use MNIST dataset")
    except ImportError as e2:
        print(f"No real trainer available ({e2}) - using mock trainer")
        REAL_TRAINER_AVAILABLE = False
# Blueprint that every route in this module is registered on.
main = Blueprint('main', __name__)

# Always-available helpers: the mock trainer and the standalone privacy
# calculator used when no real trainer is present.
mock_trainer = MockTrainer()
privacy_calculator = PrivacyCalculator()

# Build the real trainer only when an implementation was importable. If the
# constructor fails, switch the module into mock-only mode so no route ever
# sees a partially constructed trainer.
real_trainer = None
if REAL_TRAINER_AVAILABLE:
    try:
        real_trainer = RealTrainer()
        print("Real trainer initialized successfully")
    except Exception as init_error:
        print(f"Failed to initialize real trainer: {init_error}")
        REAL_TRAINER_AVAILABLE = False
        real_trainer = None
@main.route('/')
def index():
    """Render the ``index.html`` template for the site root."""
    page = 'index.html'
    return render_template(page)
@main.route('/learning')
def learning():
    """Render the ``learning.html`` template for ``/learning``."""
    page = 'learning.html'
    return render_template(page)
@main.route('/api/train', methods=['POST', 'OPTIONS'])
@cross_origin()
def train():
    """Run a training job with the requested hyperparameters.

    The body must be a JSON object. Optional keys and their defaults:
    ``clipping_norm`` (1.0), ``noise_multiplier`` (1.0), ``batch_size`` (64),
    ``learning_rate`` (0.01) and ``epochs`` (5). These are coerced with
    float()/int(). A truthy ``use_mock`` forces the mock trainer.

    The real trainer is used when one was initialised and ``use_mock`` is
    falsy; otherwise the mock trainer is used. The trainer's result dict is
    returned with ``trainer_type`` and ``dataset`` added. ``gradient_info`` is
    also added, from the same trainer, when the trainer did not supply it.

    Responses:
        200 -- training results.
        400 -- missing/empty body, body that is not a JSON object, or a
               TypeError/ValueError raised while building parameters or
               training.
        200 with ``fallback_reason`` -- training failed with another error
               and a retry with the mock trainer succeeded.
        500 -- the mock-trainer retry also failed.
    """
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'})
    # Bound before the try so the fallback handler can tell whether
    # parameter parsing finished; previously it could hit an unbound name.
    params = None
    try:
        # silent=True returns None for malformed JSON or a wrong content type
        # instead of raising, so those cases get the 400 below rather than
        # falling into the generic handler.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'No data provided'}), 400
        if not isinstance(data, dict):
            # e.g. a JSON list: .get() below would fail with AttributeError.
            raise TypeError('request body must be a JSON object')
        params = {
            'clipping_norm': float(data.get('clipping_norm', 1.0)),
            'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
            'batch_size': int(data.get('batch_size', 64)),
            'learning_rate': float(data.get('learning_rate', 0.01)),
            'epochs': int(data.get('epochs', 5))
        }
        # Check if user wants to force mock training
        use_mock = data.get('use_mock', False)
        # Decide once which trainer handles this request; the same trainer
        # also supplies gradient_info below.
        use_real = bool(REAL_TRAINER_AVAILABLE and real_trainer and not use_mock)
        if use_real:
            print("Using real trainer with MNIST dataset")
            trainer = real_trainer
            results = trainer.train(params)
            results['trainer_type'] = 'real'
            results['dataset'] = 'MNIST'
        else:
            print("Using mock trainer with synthetic data")
            trainer = mock_trainer
            results = trainer.train(params)
            results['trainer_type'] = 'mock'
            results['dataset'] = 'synthetic'
        # Add gradient information for visualization (if not already included)
        if 'gradient_info' not in results:
            results['gradient_info'] = {
                'before_clipping': trainer.generate_gradient_norms(params['clipping_norm']),
                'after_clipping': trainer.generate_clipped_gradients(params['clipping_norm'])
            }
        return jsonify(results)
    except (TypeError, ValueError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
    except Exception as e:
        print(f"Training error: {str(e)}")
        if params is None:
            # Parameter coercion itself failed (e.g. int() of an infinite
            # float raises OverflowError); there is nothing to retry with.
            return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
        # Fallback to mock trainer on any other error
        try:
            print("Falling back to mock trainer due to error")
            results = mock_trainer.train(params)
            results['trainer_type'] = 'mock'
            results['dataset'] = 'synthetic'
            results['fallback_reason'] = str(e)
            return jsonify(results)
        except Exception as fallback_error:
            return jsonify({'error': f'Server error: {str(fallback_error)}'}), 500
@main.route('/api/privacy-budget', methods=['POST', 'OPTIONS'])
@cross_origin()
def calculate_privacy_budget():
    """Return the privacy budget (epsilon) for the requested parameters.

    The body must be a JSON object. Optional keys and their defaults:
    ``clipping_norm`` (1.0), ``noise_multiplier`` (1.0), ``batch_size`` (64)
    and ``epochs`` (5). These are coerced with float()/int().

    When a real trainer was initialised, its ``_calculate_privacy_budget``
    is used. Otherwise ``PrivacyCalculator.calculate_epsilon`` is used. The
    request's ``use_mock`` flag is not consulted here.

    Responses:
        200 -- ``{'epsilon': ...}``.
        400 -- missing/empty body, body that is not a JSON object, or a
               parameter that cannot be coerced (including a TypeError/
               ValueError raised by the epsilon computation).
        500 -- any other error during the epsilon computation.
    """
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'})
    # Bound before the try so the generic handler can tell a parsing failure
    # (client error) from a computation failure (server error).
    params = None
    try:
        # silent=True returns None for malformed JSON or a wrong content type
        # instead of raising an HTTP error that would be reported as a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'error': 'No data provided'}), 400
        if not isinstance(data, dict):
            # e.g. a JSON list: .get() below would fail with AttributeError.
            raise TypeError('request body must be a JSON object')
        params = {
            'clipping_norm': float(data.get('clipping_norm', 1.0)),
            'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
            'batch_size': int(data.get('batch_size', 64)),
            'epochs': int(data.get('epochs', 5))
        }
        # Use real trainer's privacy calculation if available, otherwise use privacy calculator
        if REAL_TRAINER_AVAILABLE and real_trainer:
            epsilon = real_trainer._calculate_privacy_budget(params)
        else:
            epsilon = privacy_calculator.calculate_epsilon(params)
        return jsonify({'epsilon': epsilon})
    except (TypeError, ValueError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
    except Exception as e:
        if params is None:
            # Parameter coercion failed (e.g. int() of an infinite float
            # raises OverflowError): that is a bad request, not a server fault.
            return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
        return jsonify({'error': f'Server error: {str(e)}'}), 500
@main.route('/api/trainer-status', methods=['GET'])
@cross_origin()
def trainer_status():
    """Endpoint to check which trainer is being used.

    Reflects only whether a real trainer was successfully set up at import
    time. A request that passes ``use_mock`` to ``/api/train`` is not
    reflected here.
    """
    if REAL_TRAINER_AVAILABLE:
        trainer_kind, dataset_name = 'real', 'MNIST'
    else:
        trainer_kind, dataset_name = 'mock', 'synthetic'
    payload = {
        'real_trainer_available': REAL_TRAINER_AVAILABLE,
        'current_trainer': trainer_kind,
        'dataset': dataset_name,
    }
    return jsonify(payload)