| |
| |
| |
| |
| |
|
|
|
|
| """ |
| Run this script from the root of the repo; make sure Flask is installed first. |
| |
| FLASK_DEBUG=1 FLASK_APP=scripts.mos flask run -p 4567 |
| # or if you have gunicorn |
| gunicorn -w 4 -b 127.0.0.1:8895 -t 120 'scripts.mos:app' --access-logfile - |
| |
| """ |
| from collections import defaultdict |
| from functools import wraps |
| from hashlib import sha1 |
| import json |
| import math |
| from pathlib import Path |
| import random |
| import typing as tp |
|
|
| from flask import Flask, redirect, render_template, request, session, url_for |
|
|
| from audiocraft import train |
| from audiocraft.utils.samples.manager import get_samples_for_xps |
|
|
|
|
# Number of sample groups rated on a single survey page.
SAMPLES_PER_PAGE = 8
# Ratings go from 1 up to MAX_RATING, inclusive.
MAX_RATING = 5

# Surveys (manifests and collected results) are persisted under the Dora directory.
storage = Path(train.main.dora.dir / 'mos_storage')
surveys = storage / 'surveys'
storage.mkdir(exist_ok=True)
surveys.mkdir(exist_ok=True)

# Static files and templates are looked up in `scripts/`, two levels above the `train` module file.
magma_root = Path(train.__file__).parent.parent
app = Flask(
    'mos',
    static_folder=str(magma_root / 'scripts/static'),
    template_folder=str(magma_root / 'scripts/templates'),
)
# NOTE(review): the key protecting the session data is hard-coded in the source;
# consider loading it from the environment before exposing this app publicly.
app.secret_key = b'audiocraft makes the best songs'
|
|
|
|
def normalize_path(path: Path):
    """Return `path` expressed relative to the `xps` folder of the Dora root dir.

    Both `path` and the Dora dir are resolved first. `Path.relative_to` raises
    a `ValueError` when `path` is not located under that folder.
    """
    xps_root = train.main.dora.dir.resolve() / 'xps'
    return path.resolve().relative_to(xps_root)
|
|
|
|
def get_full_path(normalized_path: Path):
    """Inverse of `normalize_path`: rebuild the full path under the Dora `xps` folder.
    """
    xps_root = train.main.dora.dir.resolve() / 'xps'
    return xps_root / normalized_path
|
|
|
|
def get_signature(xps: tp.List[str]):
    """Compute a short identifier for a list of XP signatures.

    The list is JSON-encoded and hashed with SHA-1, and the first 10 hexadecimal
    digits of the digest are returned. The encoding keeps the list order, so the
    same signatures given in a different order produce a different identifier.
    """
    payload = json.dumps(xps).encode()
    digest = sha1(payload).hexdigest()
    return digest[:10]
|
|
|
|
def ensure_logged(func):
    """Decorate a view so that visitors without a `user` in their session are sent to the login page.

    The URL being requested is forwarded as the `redirect_to` query argument of
    the `login` view.
    """
    @wraps(func)
    def _guarded(*args, **kwargs):
        if session.get('user') is not None:
            return func(*args, **kwargs)
        return redirect(url_for('login', redirect_to=request.url))
    return _guarded
|
|
|
|
def _is_safe_redirect_target(target):
    """Tell whether `target` resolves to an http(s) URL on the host serving the current request."""
    from urllib.parse import urljoin, urlparse

    # Some browsers treat a backslash like a forward slash, which could hide a `//host` target.
    if '\\' in target:
        return False
    own = urlparse(request.host_url)
    # Relative targets are resolved against our own host before being compared.
    resolved = urlparse(urljoin(request.host_url, target))
    return resolved.scheme in ('http', 'https') and resolved.netloc == own.netloc


@app.route('/login', methods=['GET', 'POST'])
def login():
    """Login user if not already, then redirect.

    On GET without a user in the session, the login form is rendered. On POST,
    the `user` form field is stored in the session; an empty value re-renders
    the form with an error message.

    Once a user is known, the visitor is redirected to the `redirect_to` query
    argument. That argument is user-controlled, so it is only followed when it
    points to the host serving this request; otherwise, or when it is missing,
    the visitor is sent to the index page.
    """
    user = session.get('user')
    if user is None:
        error = None
        if request.method == 'POST':
            user = request.form['user']
            if not user:
                error = 'User cannot be empty'
        if user is None or error:
            return render_template('login.html', error=error)
        session['user'] = user
    redirect_to = request.args.get('redirect_to')
    if redirect_to is None or not _is_safe_redirect_target(redirect_to):
        redirect_to = url_for('index')
    return redirect(redirect_to)
|
|
|
|
@app.route('/', methods=['GET', 'POST'])
@ensure_logged
def index():
    """Offer to create a new study.

    On POST, the `xps` form field is split on whitespace. Each entry is looked
    up first as a folder under the Dora `xps` dir, then as a folder under the
    Dora `grids` dir, in which case the names of its symlinked children are
    used as XP signatures. Unknown entries, or a submission that yields no XP
    at all, are reported back on the form.

    When there is no error, a manifest listing the XPs is written under
    `surveys/<signature>/manifest.json` and the user is redirected to the
    survey page. The XPs are sorted so that the same set always maps to the
    same survey signature.
    """
    errors = []
    if request.method == 'POST':
        xps = set()
        # `split()` without argument already discards surrounding whitespace.
        for xp_or_grid in request.form['xps'].split():
            xp_path = train.main.dora.dir / 'xps' / xp_or_grid
            if xp_path.exists():
                xps.add(xp_or_grid)
                continue
            grid_path = train.main.dora.dir / 'grids' / xp_or_grid
            if grid_path.exists():
                xps.update(child.name for child in grid_path.iterdir() if child.is_symlink())
                continue
            errors.append(f'{xp_or_grid} is neither an XP nor a grid!')
        if not xps and not errors:
            # Empty submission, or grids without any symlinked XP.
            errors.append('Please provide at least one XP, or a grid containing XPs.')
        blind = 'true' if request.form.get('blind') == 'on' else 'false'
        if not errors:
            # Set iteration order can change between interpreter runs; sorting keeps the signature stable.
            xps = sorted(xps)
            signature = get_signature(xps)
            manifest = {
                'xps': xps,
            }
            survey_path = surveys / signature
            survey_path.mkdir(exist_ok=True)
            with open(survey_path / 'manifest.json', 'w') as f:
                json.dump(manifest, f, indent=2)
            return redirect(url_for('survey', blind=blind, signature=signature))
    return render_template('index.html', errors=errors)
|
|
|
|
@app.route('/survey/<signature>', methods=['GET', 'POST'])
@ensure_logged
def survey(signature):
    """Display one page of the survey `signature` and record the submitted ratings.

    Query arguments: `seed` (int, default 4321) selects which samples are shown
    and in which order; `blind`, `exclude_prompted` and `exclude_unprompted`
    are flags enabled by `true`, `on` or `True`; `max_epoch` (int, default -1)
    is forwarded to `get_samples_for_xps`; `success` is passed to the template.

    On GET, the page is rendered. On POST, one rating between 1 and
    `MAX_RATING` is expected per displayed model, in a form field named after
    its `model_id`. Missing or invalid ratings re-render the page with an error
    attached to the model. When every rating is valid, the answers are written
    to `results/<user>_<seed>.json` inside the survey folder and the user is
    redirected to the same survey with `seed + 1`.
    """
    truthy = ['true', 'on', 'True']
    success = request.args.get('success', False)
    seed = int(request.args.get('seed', 4321))
    blind = request.args.get('blind', 'false') in truthy
    exclude_prompted = request.args.get('exclude_prompted', 'false') in truthy
    exclude_unprompted = request.args.get('exclude_unprompted', 'false') in truthy
    max_epoch = int(request.args.get('max_epoch', '-1'))
    survey_path = surveys / signature
    assert survey_path.exists(), survey_path

    user = session['user']
    result_folder = survey_path / 'results'
    result_folder.mkdir(exist_ok=True)
    # NOTE(review): `user` comes from the login form and ends up in a file name — confirm
    # that path separators in the user name are not a concern for this deployment.
    result_file = result_folder / f'{user}_{seed}.json'

    with open(survey_path / 'manifest.json') as f:
        manifest = json.load(f)

    xps = [train.main.get_xp_from_sig(xp) for xp in manifest['xps']]
    names, ref_name = train.main.get_names(xps)

    samples_kwargs = {
        'exclude_prompted': exclude_prompted,
        'exclude_unprompted': exclude_unprompted,
        'max_epoch': max_epoch,
    }
    # Maps a sample id to a list of samples, assumed to be aligned with `xps` (one per XP).
    matched_samples = get_samples_for_xps(xps, epoch=-1, **samples_kwargs)
    models_by_id = {
        sample_id: [{
            'xp': xps[idx],
            'xp_name': names[idx],
            'model_id': f'{xps[idx].sig}-{sample.id}',
            'sample': sample,
            'is_prompted': sample.prompt is not None,
            'errors': [],
        } for idx, sample in enumerate(samples)]
        for sample_id, samples in matched_samples.items()
    }
    # The epoch shown for each XP is read from the first group of matched samples.
    first_group = next(iter(matched_samples.values()), [])
    experiments = [
        {'xp': xp, 'name': names[idx], 'epoch': first_group[idx].epoch}
        for idx, xp in enumerate(xps)
    ]

    # Sort before shuffling so that the selection depends only on the seed.
    keys = list(matched_samples.keys())
    keys.sort()
    rng = random.Random(seed)
    rng.shuffle(keys)
    model_ids = keys[:SAMPLES_PER_PAGE]

    if blind:
        # Shuffle the models of each displayed sample so their order does not reveal the XP.
        for key in model_ids:
            rng.shuffle(models_by_id[key])

    ok = True
    if request.method == 'POST':
        all_samples_results = []
        for sample_id in model_ids:
            models = models_by_id[sample_id]
            result = {
                'id': sample_id,
                'is_prompted': models[0]['is_prompted'],
                'models': {}
            }
            all_samples_results.append(result)
            for model in models:
                # `.get` so that a field absent from the form is reported as a missing
                # rating rather than failing the whole request.
                raw_rating = request.form.get(model['model_id'])
                if not raw_rating:
                    ok = False
                    model['errors'].append('Please rate this model.')
                    continue
                try:
                    rating = int(raw_rating)
                except ValueError:
                    rating = None
                if rating is None or not 1 <= rating <= MAX_RATING:
                    ok = False
                    model['errors'].append(f'Rating must be an integer between 1 and {MAX_RATING}.')
                    continue
                result['models'][model['xp'].sig] = rating
                model['rating'] = rating
        if ok:
            result = {
                'results': all_samples_results,
                'seed': seed,
                'user': user,
                'blind': blind,
                'exclude_prompted': exclude_prompted,
                'exclude_unprompted': exclude_unprompted,
            }
            print(result)
            with open(result_file, 'w') as f:
                json.dump(result, f)
            seed = seed + 1
            return redirect(url_for(
                'survey', signature=signature, blind=blind, seed=seed,
                exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted,
                max_epoch=max_epoch, success=True))

    ratings = list(range(1, MAX_RATING + 1))
    return render_template(
        'survey.html', ratings=ratings, blind=blind, seed=seed, signature=signature, success=success,
        exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted, max_epoch=max_epoch,
        experiments=experiments, models_by_id=models_by_id, model_ids=model_ids, errors=[],
        ref_name=ref_name, already_filled=result_file.exists())
|
|
|
|
@app.route('/audio/<path:path>')
def audio(path: str):
    """Serve the audio file located at the absolute path `/<path>`.

    Only `.mp3` and `.wav` files are served, each with its matching
    `Content-Type`. Any other extension gets a 400 response and a missing file
    a 404 response.
    """
    # NOTE(review): any .mp3/.wav file readable by the server can be fetched, and this
    # view does not require login — confirm this is acceptable for the deployment.
    content_types = {'.mp3': 'audio/mpeg', '.wav': 'audio/wav'}
    full_path = Path('/') / path
    content_type = content_types.get(full_path.suffix)
    # Explicit check rather than `assert`, which is stripped when running with `-O`.
    if content_type is None:
        return 'Unsupported audio format.', 400
    if not full_path.is_file():
        return 'Audio file not found.', 404
    return full_path.read_bytes(), {'Content-Type': content_type}
|
|
|
|
def mean(x):
    """Arithmetic mean of the values in `x`. Raises `ZeroDivisionError` when `x` is empty.
    """
    total = sum(x)
    count = len(x)
    return total / count
|
|
|
|
def std(x):
    """Population standard deviation of `x`: the squared deviations are divided by `len(x)`.
    """
    center = mean(x)
    squared_deviations = [(value - center)**2 for value in x]
    variance = sum(squared_deviations) / len(x)
    return math.sqrt(variance)
|
|
|
|
@app.route('/results/<signature>')
@ensure_logged
def results(signature):
    """Aggregate all the answers stored for the survey `signature` and render a summary.

    Every `.json` file of the survey `results` folder is read. Ratings are
    grouped by model signature, and for each model the template receives the
    number of ratings, their mean, and `1.96 * std / sqrt(n)` under the key
    `std_rating`, both formatted with two decimals.
    """
    survey_path = surveys / signature
    assert survey_path.exists(), survey_path
    result_folder = survey_path / 'results'
    result_folder.mkdir(exist_ok=True)

    # Gather the ratings of every model over all the answer files.
    ratings_per_model = defaultdict(list)
    users = []
    answer_files = (entry for entry in result_folder.iterdir() if entry.suffix == '.json')
    for answer_file in answer_files:
        with open(answer_file) as f:
            answer = json.load(f)
        users.append(answer['user'])
        for sample_result in answer['results']:
            for sig, rating in sample_result['models'].items():
                ratings_per_model[sig].append(rating)

    def _summarize(sig):
        ratings = ratings_per_model[sig]
        # Presumably the half-width of a 95% confidence interval on the mean.
        interval = 1.96 * std(ratings) / len(ratings)**0.5
        return {
            'sig': sig,
            'samples': len(ratings),
            'mean_rating': f'{mean(ratings):.2f}',
            'std_rating': f'{interval:.2f}',
        }

    models = [_summarize(sig) for sig in sorted(ratings_per_model)]
    return render_template('results.html', signature=signature, models=models, users=users)
|
|