Spaces:
Runtime error
Runtime error
import os | |
import wandb | |
from huggingface_hub import HfApi | |
from pathlib import Path | |
import huggingface_hub | |
import ssl | |
import os | |
# WARNING(review): the two statements below weaken TLS for the whole process.
# Rebinding ssl._create_default_https_context to the unverified factory makes
# stdlib HTTPS clients that do not supply their own SSLContext skip
# certificate and hostname verification (the global opt-out described in
# PEP 476). Clearing CURL_CA_BUNDLE is presumably meant to do the same for
# `requests`-based clients; whether it has any effect depends on the installed
# requests version -- confirm. Keep this only if the network truly needs it.
os.environ.update(CURL_CA_BUNDLE='')
ssl._create_default_https_context = ssl._create_unverified_context
class Uploader:
    """Pull a model checkpoint logged to a W&B run and push it to the Hugging Face Hub.

    The source run is looked up by name inside ``f"{entity}/{project}"``.
    The destination Hub repository is ``full_repo_id``,
    i.e. ``f"{username}/{repo_id}"``.
    """

    def __init__(self, entity, project, run_name, repo_id, username):
        """Store the run/repo coordinates, create API clients and log in to the Hub.

        Args:
            entity: W&B entity that owns ``project``.
            project: W&B project containing the run.
            run_name: Name (``run.name``) of the W&B run to take the model from.
            repo_id: Hub repository name, without namespace.
            username: Hub namespace the repository lives under.

        The Hub token comes from the ``HUGGINGFACE_TOKEN`` environment
        variable. If it is unset, ``None`` is passed and
        ``huggingface_hub.login``'s own fallback applies.
        """
        self.entity = entity
        self.project = project
        self.run_name = run_name
        self.hf_api = HfApi()
        self.wandb_api = wandb.Api()
        self.repo_id = repo_id
        self.username = username
        huggingface_hub.login(os.environ.get('HUGGINGFACE_TOKEN'))

    @property
    def full_repo_id(self):
        """Return the namespaced Hub repository id, ``"<username>/<repo_id>"``."""
        return f'{self.username}/{self.repo_id}'

    @staticmethod
    def _first(items, predicate, description):
        """Return the first element of *items* for which *predicate* is true.

        Iteration stops at the first match, so lazily paginated collections
        are not fully consumed.

        Raises:
            IndexError: if nothing matches. This is the same type the
                original ``[...][0]`` lookups raised, now with a message
                naming *description*.
        """
        for item in items:
            if predicate(item):
                return item
        raise IndexError(f"No {description} found")

    def get_model_from_wandb_run(self):
        """Download the run's first ``model`` artifact and return its checkpoint path.

        Prints the artifact's ``'Validation score'`` metadata entry
        (``None`` if the entry is absent).

        Returns:
            Absolute POSIX path of a ``*.pt`` file in the downloaded artifact
            directory. If several exist, the first in sorted order is used.

        Raises:
            IndexError: if no run named ``self.run_name`` exists in the
                project, the run logged no ``model`` artifact, or the
                artifact directory contains no ``*.pt`` file.
        """
        project_path = f"{self.entity}/{self.project}"
        runs = self.wandb_api.runs(project_path)
        run = self._first(
            runs,
            lambda r: r.name == self.run_name,
            f"run named {self.run_name!r} in {project_path}",
        )
        # The first logged model artifact is taken. Nothing here ranks
        # artifacts, so which one that is depends on how the run logged them.
        model_artifact = self._first(
            run.logged_artifacts(),
            lambda a: a.type == 'model',
            f"'model' artifact logged by run {self.run_name!r}",
        )
        artifact_dir = model_artifact.download()
        # Sort so the pick is deterministic; raw glob order is filesystem-dependent.
        checkpoint = self._first(
            sorted(Path(artifact_dir).glob("*.pt")),
            lambda p: True,
            f"*.pt file in {artifact_dir}",
        )
        # .get(): a missing metadata entry should not abort the upload.
        print(f"Model validation score = {model_artifact.metadata.get('Validation score')}")
        return checkpoint.absolute().as_posix()

    def upload_to_HF(self):
        """Fetch the run's checkpoint and upload it to the root of ``full_repo_id``.

        The file keeps its local file name inside the repository.
        """
        model_path = self.get_model_from_wandb_run()
        self.hf_api.upload_file(
            path_or_fileobj=model_path,
            path_in_repo=Path(model_path).name,
            repo_id=self.full_repo_id,
        )

    def create_repo(self):
        """Create the Hub repository ``full_repo_id`` unless it already exists."""
        # Use the same namespaced id that upload_to_HF pushes to. A bare
        # repo_id would be created under the token owner's namespace, which
        # need not be ``username`` (e.g. an organisation).
        self.hf_api.create_repo(repo_id=self.full_repo_id, exist_ok=True)
if __name__ == '__main__':
    # Source: run "wav_normalization" in W&B project borisovmaksim/denoising.
    # Destination: Hub repo "demucs" under namespace BorisovMaksim.
    run_settings = {
        'entity': 'borisovmaksim',
        'project': 'denoising',
        'run_name': 'wav_normalization',
        'repo_id': 'demucs',
        'username': 'BorisovMaksim',
    }
    uploader = Uploader(**run_settings)
    # Create the Hub repo (no-op if it already exists), then push the checkpoint.
    uploader.create_repo()
    uploader.upload_to_HF()