# Joseph Feng
# update link and fix assertion bug
# 52f268c
import typer
import torch
import subprocess
from pathlib import Path
from expert import UpstreamExpert
# Files that must exist in the current working directory for a valid submission.
SUBMISSION_FILES = [
    "expert.py",
    "model.pt",
]

# Samples per second of the random test waveforms built in `validate`.
SAMPLE_RATE = 16_000

# Durations (in seconds) of those test waveforms; the lengths differ, so the
# upstream is exercised on a batch of unequal-length inputs.
SECONDS = [2, 1.8, 3.7]

# Typer application; every `@app.command()` function below becomes a sub-command.
app = typer.Typer()
def _require(condition, message):
    """Raise ``AssertionError(message)`` when *condition* is falsy.

    Used instead of the ``assert`` statement so the checks are still performed
    when Python runs with ``-O`` (which strips ``assert`` statements).
    """
    if not condition:
        raise AssertionError(message)


@app.command()
def validate():
    """Check the submission files and the behaviour of ``UpstreamExpert``.

    Checks, in order:

    1. Every file in ``SUBMISSION_FILES`` exists in the current working directory.
    2. ``UpstreamExpert(ckpt="model.pt")`` can be built and called on a list of
       random 1-D waveforms, one per entry of ``SECONDS`` at ``SAMPLE_RATE``.
    3. The call returns a dict containing ``"hidden_states"``. For every task,
       the task's own entry (or ``"hidden_states"`` when absent) is a non-empty
       list of 3-D ``torch.Tensor`` objects that all share one shape.
    4. ``upstream.get_downsample_rates(task)`` returns a positive int, and the
       longest waveform's length divided by it differs from the returned
       sequence length by fewer than 5 frames.

    Raises:
        ValueError: A required submission file is missing.
        AssertionError: The upstream output violates one of the checks above.
        Exception: Anything raised by the upstream itself; it is re-raised after
            a pointer to the upstream specification is printed.
    """
    # Check that all the expected files exist
    for file in SUBMISSION_FILES:
        if not Path(file).is_file():
            raise ValueError(f"File {file} not found! Please include {file} in your submission")
    # NOTE(review): `UpstreamExpert` is imported at the top of this file, so a
    # missing expert.py fails with ImportError at CLI start-up, before this check.

    try:
        upstream = UpstreamExpert(ckpt="model.pt")
        samples = [round(SAMPLE_RATE * sec) for sec in SECONDS]
        wavs = [torch.rand(sample) for sample in samples]
        results = upstream(wavs)

        _require(isinstance(results, dict), "upstream(wavs) must return a dict")
        # "hidden_states" is the fallback for every task lacking its own key,
        # so it is required.
        _require("hidden_states" in results, 'the returned dict must contain a "hidden_states" key')

        tasks = ["PR", "SID", "ER", "ASR", "ASV", "SD", "QbE", "ST", "SS", "SE", "secret"]
        for task in tasks:
            # A task-specific entry overrides the shared "hidden_states" entry.
            hidden_states = results.get(task, results["hidden_states"])
            _require(isinstance(hidden_states, list), f"[{task}] hidden states must be a list of tensors")
            _require(len(hidden_states) > 0, f"[{task}] hidden states must not be an empty list")
            for state in hidden_states:
                _require(isinstance(state, torch.Tensor), f"[{task}] every hidden state must be a torch.Tensor")
                _require(state.dim() == 3, "(batch_size, max_sequence_length_of_batch, hidden_size)")
                _require(
                    state.shape == hidden_states[0].shape,
                    f"[{task}] all hidden states must have the same shape",
                )

            downsample_rate = upstream.get_downsample_rates(task)
            _require(isinstance(downsample_rate, int), f"[{task}] downsample rate must be an int")
            _require(downsample_rate > 0, f"[{task}] downsample rate must be positive")
            # The sequence axis is padded to the longest input, so compare against
            # max(samples); the `< 5` slack presumably absorbs framing/padding
            # differences between models.
            expected_frames = round(max(samples) / downsample_rate)
            _require(abs(expected_frames - hidden_states[0].size(1)) < 5, "wrong downsample rate")
    except Exception:
        print("Please check the Upstream Specification on https://superbbenchmark.org/challenge-slt2022/upstream")
        raise

    typer.echo("All submission files validated!")
    typer.echo("Now you can upload these files to huggingface's Hub.")
def _run_git(*args):
    """Run ``git <args>`` and abort the CLI if it fails.

    Output goes straight to the terminal. On a non-zero exit status an error is
    echoed to stderr and ``typer.Exit`` is raised with git's exit code (or 1 if
    git was terminated by a signal, which ``subprocess`` reports as negative).
    """
    command = ["git", *args]
    returncode = subprocess.call(command)
    if returncode != 0:
        typer.echo(f"Command failed with exit code {returncode}: {' '.join(command)}", err=True)
        raise typer.Exit(code=returncode if returncode > 0 else 1)


@app.command()
def upload(commit_message: str):
    """Commit the working tree and push it to the ``origin`` remote.

    Runs ``git pull origin main``, ``git add .``, ``git commit`` (only when
    something is staged) and ``git push``. It stops at the first failing git
    command, so the success message is printed only after a successful push.

    Args:
        commit_message: Text appended to the ``"Upload Upstream: "`` commit subject.

    Raises:
        typer.Exit: A git command returned a non-zero exit status.
    """
    _run_git("pull", "origin", "main")
    _run_git("add", ".")
    # `git diff --cached --quiet` exits 0 when nothing is staged. In that case,
    # skip the commit (which would fail with "nothing to commit") but still push.
    if subprocess.call(["git", "diff", "--cached", "--quiet"]) != 0:
        _run_git("commit", "-m", f"Upload Upstream: {commit_message} ")
    _run_git("push")

    typer.echo("Upload successful!")
    typer.echo("Please go to https://superbbenchmark.org/submit to make a submission with the following information:")
    typer.echo("1. Organization Name")
    typer.echo("2. Repository Name")
    typer.echo("3. Commit Hash (full 40 characters)")
    typer.echo("These information can be shown by: python cli.py info")
def _parse_remote_url(url):
    """Return ``(organization, repository)`` parsed from a git remote URL.

    Handles HTTPS URLs (``https://huggingface.co/<org>/<repo>``) and SCP-style
    SSH URLs (``git@host:<org>/<repo>``). A trailing ``/`` and a ``.git``
    suffix are ignored.

    Raises:
        ValueError: *url* has fewer than two path components (e.g. it is empty
            because no ``origin`` remote is configured).
    """
    path = url.strip().rstrip("/")
    if path.endswith(".git"):
        path = path[: -len(".git")]
    # Treat ':' as a separator too, so SCP-style "host:org/repo" splits like a URL path.
    parts = [part for part in path.replace(":", "/").split("/") if part]
    if len(parts) < 2:
        # Deliberately do not echo the URL: HTTPS remotes may embed credentials.
        raise ValueError("expected a remote URL ending in <organization>/<repository>")
    return parts[-2], parts[-1]


@app.command()
def info():
    """Print the organization, repository name and commit hash to submit.

    The organization and repository are parsed from ``remote.origin.url``. The
    commit hash is the full hash printed by ``git rev-parse HEAD``.

    Raises:
        typer.Exit: No usable ``origin`` remote URL, or ``HEAD`` cannot be resolved.
    """
    result = subprocess.run(["git", "config", "--get", "remote.origin.url"], capture_output=True)
    url = result.stdout.decode("utf-8").strip()
    try:
        organization, repo = _parse_remote_url(url)
    except ValueError as err:
        typer.echo(f"Could not read organization/repository from remote.origin.url: {err}", err=True)
        raise typer.Exit(code=1) from err

    result = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True)
    if result.returncode != 0:
        reason = result.stderr.decode("utf-8").strip()
        typer.echo(f"Could not resolve the current commit: {reason}", err=True)
        raise typer.Exit(code=1)
    commit_hash = result.stdout.decode("utf-8").strip()

    typer.echo(f"Organization Name: {organization}")
    typer.echo(f"Repository Name: {repo}")
    typer.echo(f"Commit Hash: {commit_hash}")
if __name__ == "__main__":
    # Script entry point: Typer reads sys.argv and dispatches to the
    # sub-command (validate / upload / info) named there.
    app()