"""Command-line helper for preparing a SUPERB upstream submission.

Commands:

* ``validate`` - check that the submission files exist and that the upstream
  model produces outputs of the expected structure.
* ``upload``   - commit the working tree and push it to the remote repository.
* ``info``     - print the organization, repository and commit hash that the
  submission form asks for.
"""

import subprocess
from pathlib import Path
from typing import Tuple

import torch
import typer

from expert import UpstreamExpert

SUBMISSION_FILES = ["expert.py", "model.pt"]
SAMPLE_RATE = 16000
# Durations (in seconds) of the random waveforms fed to the upstream model.
SECONDS = [2, 1.8, 3.7]

app = typer.Typer()


def _require(condition: bool, message: str) -> None:
    """Raise ``AssertionError(message)`` unless *condition* is true.

    Used instead of the ``assert`` statement so that the checks still run
    under ``python -O``, while keeping the exception type that the original
    assertions raised.
    """
    if not condition:
        raise AssertionError(message)


def _run_git(*args: str) -> None:
    """Run ``git <args>`` and abort the CLI command if git reports a failure."""
    completed = subprocess.run(["git", *args])
    if completed.returncode != 0:
        typer.echo(
            f"Command failed: git {' '.join(args)} (exit code {completed.returncode})",
            err=True,
        )
        raise typer.Exit(code=1)


def _git_output(*args: str) -> str:
    """Return the stripped stdout of ``git <args>``; abort the command on failure."""
    completed = subprocess.run(["git", *args], capture_output=True)
    if completed.returncode != 0:
        typer.echo(
            f"Command failed: git {' '.join(args)} (exit code {completed.returncode})",
            err=True,
        )
        raise typer.Exit(code=1)
    return completed.stdout.decode("utf-8").strip()


def _parse_remote_url(url: str) -> Tuple[str, str]:
    """Split a git remote URL into ``(organization, repository)``.

    The last two ``/``-separated components are used.  A trailing ``/`` or
    ``.git`` suffix is dropped, and for scp-style remotes
    (``user@host:org/repo``) the ``user@host:`` prefix is removed from the
    organization.

    Raises:
        ValueError: if the URL does not contain two non-empty components.
    """
    cleaned = url.strip().rstrip("/")
    if cleaned.endswith(".git"):
        cleaned = cleaned[: -len(".git")]

    parts = cleaned.split("/")
    if len(parts) < 2 or not parts[-1] or not parts[-2]:
        raise ValueError(f"Cannot parse organization/repository from remote URL: {url!r}")

    organization, repo = parts[-2], parts[-1]
    if "://" not in cleaned:
        # scp-style remote: "git@host:org" -> "org"
        organization = organization.rsplit(":", 1)[-1]
    return organization, repo


@app.command()
def validate():
    """Check the submission files and the structure of the upstream's outputs.

    Raises ValueError when a required file is missing and AssertionError when
    the upstream's outputs do not have the expected structure.  Any other
    exception raised by the upstream itself is re-raised after printing a
    pointer to the specification.
    """
    # Check that all the expected files exist
    for file in SUBMISSION_FILES:
        if not Path(file).is_file():
            raise ValueError(f"File {file} not found! Please include {file} in your submission")

    try:
        upstream = UpstreamExpert(ckpt="model.pt")
        samples = [round(SAMPLE_RATE * sec) for sec in SECONDS]
        wavs = [torch.rand(sample) for sample in samples]
        results = upstream(wavs)
        _require(isinstance(results, dict), "the upstream must return a dict")

        tasks = ["PR", "SID", "ER", "ASR", "ASV", "SD", "QbE", "ST", "SS", "SE", "secret"]
        for task in tasks:
            # NOTE: results["hidden_states"] is evaluated before .get() runs,
            # so the "hidden_states" key must be present even when every
            # task has its own entry.
            hidden_states = results.get(task, results["hidden_states"])
            _require(isinstance(hidden_states, list), f"results for task {task} must be a list")
            # Guard the hidden_states[0] accesses below against an empty list.
            _require(len(hidden_states) > 0, f"results for task {task} must not be empty")

            for state in hidden_states:
                _require(isinstance(state, torch.Tensor), f"results for task {task} must be torch.Tensor")
                _require(state.dim() == 3, "(batch_size, max_sequence_length_of_batch, hidden_size)")
                _require(
                    state.shape == hidden_states[0].shape,
                    f"all hidden states for task {task} must share the same shape",
                )

            downsample_rate = upstream.get_downsample_rates(task)
            _require(isinstance(downsample_rate, int), f"downsample rate for task {task} must be an int")
            # A zero rate would otherwise surface as a ZeroDivisionError below.
            _require(downsample_rate > 0, f"downsample rate for task {task} must be positive")

            # The frame count implied by the downsample rate must be within
            # 5 frames of the actual sequence length of the longest input.
            expected_frames = round(max(samples) / downsample_rate)
            _require(abs(expected_frames - hidden_states[0].size(1)) < 5, "wrong downsample rate")

    except Exception:
        # Exception (not a bare except) so Ctrl-C / SystemExit pass through
        # without the misleading hint.
        print("Please check the Upstream Specification on https://superbbenchmark.org/challenge-slt2022/upstream")
        raise

    typer.echo("All submission files validated!")
    typer.echo("Now you can upload these files to huggingface's Hub.")


@app.command()
def upload(commit_message: str):
    """Pull, stage everything, commit with COMMIT_MESSAGE and push.

    The command stops with a non-zero exit status as soon as a git step
    fails, so the success message is only printed after a successful push.
    """
    _run_git("pull", "origin", "main")
    _run_git("add", ".")

    # `git diff --cached --quiet` exits 0 when nothing is staged.  In that
    # case skip the commit (git would reject an empty one) and still push any
    # commits that exist locally.
    staged = subprocess.run(["git", "diff", "--cached", "--quiet"])
    if staged.returncode != 0:
        _run_git("commit", "-m", f"Upload Upstream: {commit_message} ")
    else:
        typer.echo("No new changes to commit; pushing existing commits.")

    _run_git("push")

    typer.echo("Upload successful!")
    typer.echo("Please go to https://superbbenchmark.org/submit to make a submission with the following information:")
    typer.echo("1. Organization Name")
    typer.echo("2. Repository Name")
    typer.echo("3. Commit Hash (full 40 characters)")
    typer.echo("These information can be shown by: python cli.py info")


@app.command()
def info():
    """Print the organization name, repository name and HEAD commit hash.

    The names are derived from the URL of the ``origin`` remote.  The command
    exits with a non-zero status if git fails or the URL cannot be parsed.
    """
    url = _git_output("config", "--get", "remote.origin.url")
    try:
        organization, repo = _parse_remote_url(url)
    except ValueError as err:
        typer.echo(str(err), err=True)
        raise typer.Exit(code=1)

    commit_hash = _git_output("rev-parse", "HEAD")

    typer.echo(f"Organization Name: {organization}")
    typer.echo(f"Repository Name: {repo}")
    typer.echo(f"Commit Hash: {commit_hash}")


if __name__ == "__main__":
    app()