lmzjms's picture
Upload 1162 files
0b32ad6 verified
raw
history blame
1.03 kB
import os
import torch
from s3prl.util.download import _urls_to_filepaths
from .expert import UpstreamExpert as _UpstreamExpert
def decoar2_custom(ckpt: str, refresh=False, *args, **kwargs):
    """
    Build the DeCoAR 2.0 upstream from a local checkpoint path or a URL.

    ckpt (str | os.PathLike): local checkpoint path, or an http(s) URL that is
        resolved to a local file via _urls_to_filepaths first
    refresh (bool): forwarded to _urls_to_filepaths; only used when ckpt is a URL
    *args, **kwargs: forwarded unchanged to the UpstreamExpert constructor

    Returns the constructed UpstreamExpert.
    """
    # Accept pathlib.Path and other path-like objects, not only plain str.
    ckpt = os.fspath(ckpt)
    # Match the URL scheme explicitly so that a local file whose name merely
    # starts with "http" is not mistaken for a URL to download.
    if ckpt.startswith(("http://", "https://")):
        ckpt = _urls_to_filepaths(ckpt, refresh=refresh)
    return _UpstreamExpert(ckpt, *args, **kwargs)
def decoar2_local(*args, **kwargs):
    """
    Load the DeCoAR 2.0 upstream from a checkpoint file on disk.

    ckpt (str): PATH to the local checkpoint
    feature_selection (str): 'c' (default) or 'z'

    Every argument is handed to decoar2_custom unchanged.
    """
    upstream = decoar2_custom(*args, **kwargs)
    return upstream
def decoar2_url(*args, **kwargs):
    """
    Load the DeCoAR 2.0 upstream from a checkpoint addressed by URL.

    ckpt (str): URL of the checkpoint

    Every argument (including `refresh`) is handed to decoar2_custom unchanged.
    """
    upstream = decoar2_custom(*args, **kwargs)
    return upstream
def decoar2(*args, refresh=False, **kwargs):
    """
    The DeCoAR 2.0 standard model, loaded from the converted checkpoint
    hosted on Hugging Face (s3prl/converted_ckpts).

    refresh (bool): whether to download the ckpt again if it already exists

    Note: any `ckpt` passed in kwargs is overridden by the pinned URL below.
    Remaining args/kwargs are forwarded to decoar2_url.
    """
    # Always use the pinned checkpoint; a caller-supplied `ckpt` is replaced.
    kwargs["ckpt"] = (
        "https://huggingface.co/s3prl/converted_ckpts/resolve/main/checkpoint_decoar2.pt"
    )
    return decoar2_url(*args, refresh=refresh, **kwargs)