|
from flask import Blueprint, render_template, jsonify, request, current_app |
|
from app.training.mock_trainer import MockTrainer |
|
from app.training.privacy_calculator import PrivacyCalculator |
|
from flask_cors import cross_origin |
|
import os |
|
|
|
|
|
# Select a real (MNIST) trainer implementation, preferring the lightweight
# SimplifiedRealTrainer and falling back to the full RealTrainer. Whichever
# import succeeds is bound to the name ``RealTrainer``. If neither can be
# imported, REAL_TRAINER_AVAILABLE is False and the routes use the mock trainer.
try:
    from app.training.simplified_real_trainer import SimplifiedRealTrainer as RealTrainer

    REAL_TRAINER_AVAILABLE = True
    print("Simplified real trainer available - will use MNIST dataset")
except ImportError as e:
    # The simplified trainer just failed to import; the next candidate is the
    # *full* implementation, so the log message must say that.
    print(f"Simplified real trainer not available ({e}) - trying full version")
    try:
        from app.training.real_trainer import RealTrainer

        REAL_TRAINER_AVAILABLE = True
        print("Full real trainer available - will use MNIST dataset")
    except ImportError as e2:
        print(f"No real trainer available ({e2}) - using mock trainer")
        REAL_TRAINER_AVAILABLE = False
|
|
|
# Blueprint that owns every route defined in this module.
main = Blueprint(name='main', import_name=__name__)

# Module-level helpers shared by the API routes. The mock trainer is always
# built because /api/train falls back to it; the privacy calculator answers
# /api/privacy-budget whenever the real trainer is not used.
mock_trainer = MockTrainer()
privacy_calculator = PrivacyCalculator()
|
|
|
|
|
# Build the real trainer once at import time. If construction fails, clear
# REAL_TRAINER_AVAILABLE so every route (and /api/trainer-status) agrees that
# only the mock trainer is usable.
real_trainer = None
if REAL_TRAINER_AVAILABLE:
    try:
        real_trainer = RealTrainer()
        print("Real trainer initialized successfully")
    except Exception as e:
        print(f"Failed to initialize real trainer: {e}")
        REAL_TRAINER_AVAILABLE = False
        real_trainer = None
|
|
|
@main.route('/')
def index():
    """Render the landing page template (``index.html``)."""
    template_name = 'index.html'
    return render_template(template_name)
|
|
|
@main.route('/learning')
def learning():
    """Render the learning page template (``learning.html``)."""
    template_name = 'learning.html'
    return render_template(template_name)
|
|
|
def _parse_train_params(data):
    """Convert the ``/api/train`` JSON payload into typed hyperparameters.

    Args:
        data: The decoded JSON request body (a dict). Missing keys fall back
            to the defaults below.

    Returns:
        A dict with ``clipping_norm``, ``noise_multiplier``, ``batch_size``,
        ``learning_rate`` and ``epochs``.

    Raises:
        TypeError, ValueError: if a supplied value cannot be converted
            (e.g. ``null`` or ``"abc"``).
    """
    return {
        'clipping_norm': float(data.get('clipping_norm', 1.0)),
        'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
        'batch_size': int(data.get('batch_size', 64)),
        'learning_rate': float(data.get('learning_rate', 0.01)),
        'epochs': int(data.get('epochs', 5)),
    }


def _train_with_mock(params):
    """Train with the mock trainer and tag the results as synthetic-data output."""
    results = mock_trainer.train(params)
    results['trainer_type'] = 'mock'
    results['dataset'] = 'synthetic'
    return results


def _ensure_gradient_info(results, trainer, clipping_norm):
    """Add ``gradient_info`` to ``results`` using ``trainer`` if it is missing.

    ``results`` is mutated in place and also returned for convenience.
    """
    if 'gradient_info' not in results:
        results['gradient_info'] = {
            'before_clipping': trainer.generate_gradient_norms(clipping_norm),
            'after_clipping': trainer.generate_clipped_gradients(clipping_norm),
        }
    return results


@main.route('/api/train', methods=['POST', 'OPTIONS'])
@cross_origin()
def train():
    """Run a training job with the hyperparameters in the JSON request body.

    Body fields (all optional, defaults in ``_parse_train_params``):
    ``clipping_norm``, ``noise_multiplier``, ``batch_size``,
    ``learning_rate``, ``epochs``. A truthy ``use_mock`` forces the mock
    trainer even when a real trainer is available.

    Returns the trainer's result dict as JSON, augmented with
    ``trainer_type``, ``dataset`` and ``gradient_info``. If training fails,
    the mock trainer is tried instead and ``fallback_reason`` is added.

    Error responses:
        400 -- missing or malformed JSON, a non-object body, or a parameter
               that cannot be converted to a number.
        500 -- both the selected trainer and the mock fallback failed.
    """
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'})

    # silent=True returns None for malformed JSON or a non-JSON content type
    # instead of raising, so such requests get a 400 rather than a 500.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({'error': 'No data provided'}), 400
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    # Parse before training: the fallback below needs ``params`` to exist,
    # and only conversion errors should be reported as bad parameters.
    try:
        params = _parse_train_params(data)
    except (TypeError, ValueError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400

    use_real = (
        REAL_TRAINER_AVAILABLE
        and real_trainer is not None
        and not data.get('use_mock', False)
    )

    try:
        if use_real:
            print("Using real trainer with MNIST dataset")
            results = real_trainer.train(params)
            results['trainer_type'] = 'real'
            results['dataset'] = 'MNIST'
            trainer = real_trainer
        else:
            print("Using mock trainer with synthetic data")
            results = _train_with_mock(params)
            trainer = mock_trainer
        _ensure_gradient_info(results, trainer, params['clipping_norm'])
        return jsonify(results)
    except Exception as e:
        # Request boundary: log with traceback, then retry on the mock trainer
        # so the client still receives a usable result.
        current_app.logger.exception("Training error: %s", e)
        try:
            print("Falling back to mock trainer due to error")
            results = _train_with_mock(params)
            # Keep the response shape identical to the normal mock path.
            _ensure_gradient_info(results, mock_trainer, params['clipping_norm'])
            results['fallback_reason'] = str(e)
            return jsonify(results)
        except Exception as fallback_error:
            current_app.logger.exception("Mock fallback failed: %s", fallback_error)
            return jsonify({'error': f'Server error: {str(fallback_error)}'}), 500
|
|
|
@main.route('/api/privacy-budget', methods=['POST', 'OPTIONS'])
@cross_origin()
def calculate_privacy_budget():
    """Compute the privacy budget (epsilon) for the requested hyperparameters.

    Body fields (all optional): ``clipping_norm`` (default 1.0),
    ``noise_multiplier`` (1.0), ``batch_size`` (64), ``epochs`` (5) and
    ``use_mock`` (False).

    Epsilon comes from ``real_trainer._calculate_privacy_budget`` when a real
    trainer is available and ``use_mock`` is falsy. This is the same selection
    rule ``/api/train`` uses. Otherwise it comes from
    ``privacy_calculator.calculate_epsilon``.

    Returns ``{"epsilon": ...}``.

    Error responses:
        400 -- missing or malformed JSON, a non-object body, or a
               TypeError/ValueError while converting parameters or
               computing epsilon.
        500 -- any other failure.
    """
    if request.method == 'OPTIONS':
        return jsonify({'status': 'ok'})

    # silent=True: malformed JSON or a non-JSON content type yields None
    # (-> 400 below) instead of an exception that surfaced as a 500.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({'error': 'No data provided'}), 400
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    try:
        params = {
            'clipping_norm': float(data.get('clipping_norm', 1.0)),
            'noise_multiplier': float(data.get('noise_multiplier', 1.0)),
            'batch_size': int(data.get('batch_size', 64)),
            'epochs': int(data.get('epochs', 5)),
        }

        # Honour ``use_mock`` so the reported budget matches the trainer the
        # client actually runs via /api/train.
        use_real = (
            REAL_TRAINER_AVAILABLE
            and real_trainer is not None
            and not data.get('use_mock', False)
        )
        if use_real:
            epsilon = real_trainer._calculate_privacy_budget(params)
        else:
            epsilon = privacy_calculator.calculate_epsilon(params)

        return jsonify({'epsilon': epsilon})
    except (TypeError, ValueError) as e:
        return jsonify({'error': f'Invalid parameter values: {str(e)}'}), 400
    except Exception as e:
        current_app.logger.exception("Privacy budget error: %s", e)
        return jsonify({'error': f'Server error: {str(e)}'}), 500
|
|
|
@main.route('/api/trainer-status', methods=['GET'])
@cross_origin()
def trainer_status():
    """Report which trainer backs /api/train by default.

    The answer reflects only the module-level REAL_TRAINER_AVAILABLE flag;
    it does not account for a client opting into the mock via ``use_mock``.
    """
    if REAL_TRAINER_AVAILABLE:
        trainer_kind, dataset_name = 'real', 'MNIST'
    else:
        trainer_kind, dataset_name = 'mock', 'synthetic'

    return jsonify({
        'real_trainer_available': REAL_TRAINER_AVAILABLE,
        'current_trainer': trainer_kind,
        'dataset': dataset_name,
    })