denoising / upload_to_HF.py
BorisovMaksim's picture
fix app.py
08d9656
raw
history blame
1.93 kB
import os
import wandb
from huggingface_hub import HfApi
from pathlib import Path
import huggingface_hub
import ssl
import os
# SECURITY NOTE(review): the two statements below run at import time and relax
# HTTPS certificate handling for everything that executes afterwards in this
# process. Confirm this is really required (e.g. a broken corporate proxy)
# before keeping it in a script that transmits an access token.
# Blank out CURL_CA_BUNDLE; presumably this makes requests-based clients skip
# CA bundle verification -- confirm against the HTTP library in use.
os.environ['CURL_CA_BUNDLE'] = ''
# Swap the ssl module's private default-HTTPS-context factory for the
# unverified one, so contexts created through it do not verify certificates.
ssl._create_default_https_context = ssl._create_unverified_context
class Uploader:
    """Copy a model checkpoint from a Weights & Biases run to the Hugging Face Hub.

    The checkpoint is the first ``*.pt`` file found in the first ``model``-type
    artifact logged by the run named ``run_name`` inside ``entity/project``.
    It is uploaded to the Hub repository ``username/repo_id``.
    """

    def __init__(self, entity: str, project: str, run_name: str,
                 repo_id: str, username: str):
        """Store the source/target coordinates and log in to the Hub.

        Args:
            entity: W&B entity that owns the project.
            project: W&B project name.
            run_name: Name of the W&B run whose model artifact is uploaded.
            repo_id: Repository name on the Hub, without the namespace.
            username: Hub namespace that owns the repository.
        """
        self.entity = entity
        self.project = project
        self.run_name = run_name
        self.hf_api = HfApi()
        self.wandb_api = wandb.Api()
        self.repo_id = repo_id
        self.username = username
        # The token is read from the HUGGINGFACE_TOKEN environment variable;
        # when the variable is unset, None is handed to login().
        huggingface_hub.login(os.environ.get('HUGGINGFACE_TOKEN'))

    def _full_repo_id(self) -> str:
        """Return the namespaced Hub repository id, ``username/repo_id``."""
        return f'{self.username}/{self.repo_id}'

    def get_model_from_wandb_run(self) -> str:
        """Download the run's model artifact and return the checkpoint path.

        Returns:
            Absolute POSIX-style path of the first ``*.pt`` file found in the
            downloaded artifact directory.

        Raises:
            IndexError: If no run named ``run_name`` exists in the project, the
                run has no ``model``-type artifact, or the artifact directory
                contains no ``*.pt`` file.
            KeyError: If the artifact metadata has no ``'Validation score'``.
        """
        project_path = f"{self.entity}/{self.project}"
        runs = self.wandb_api.runs(project_path)
        # next() stops at the first match instead of walking every run.
        run = next((r for r in runs if r.name == self.run_name), None)
        if run is None:
            raise IndexError(
                f"no W&B run named {self.run_name!r} found in {project_path}"
            )

        # No ranking is applied: the first model-type artifact wins.
        best_model = next(
            (a for a in run.logged_artifacts() if a.type == 'model'), None
        )
        if best_model is None:
            raise IndexError(
                f"W&B run {self.run_name!r} has no artifact of type 'model'"
            )

        artifact_dir = best_model.download()
        checkpoint = next(iter(Path(artifact_dir).glob("*.pt")), None)
        if checkpoint is None:
            raise IndexError(f"no '*.pt' file found in {artifact_dir}")
        model_path = checkpoint.absolute().as_posix()
        print(f"Model validation score = {best_model.metadata['Validation score']}")
        return model_path

    def upload_to_HF(self) -> None:
        """Fetch the checkpoint from W&B and upload it to ``username/repo_id``.

        The file keeps its own name at the root of the Hub repository.
        """
        model_path = self.get_model_from_wandb_run()
        self.hf_api.upload_file(
            path_or_fileobj=model_path,
            path_in_repo=Path(model_path).name,
            repo_id=self._full_repo_id(),
        )

    def create_repo(self) -> None:
        """Create the Hub repository ``username/repo_id`` if it is missing.

        The namespaced id is used so the repository created here is the same
        one ``upload_to_HF`` writes to.
        """
        self.hf_api.create_repo(repo_id=self._full_repo_id(), exist_ok=True)
if __name__ == '__main__':
    # Script entry point: publish the checkpoint of one fixed W&B run to the
    # Hub, making sure the target repository exists before uploading.
    settings = {
        'entity': 'borisovmaksim',
        'project': 'denoising',
        'run_name': 'wav_normalization',
        'repo_id': 'demucs',
        'username': 'BorisovMaksim',
    }
    uploader = Uploader(**settings)
    uploader.create_repo()
    uploader.upload_to_HF()