# EMusicGen / utils.py
# Header scraped from the hosting page: uploaded by "admin" in commit 86eaed9
# ("upl base"); size 2.05 kB; virus scan: no virus found.
import os
import sys
import time
import torch
import requests
import subprocess
from tqdm import tqdm
from modelscope.hub.api import HubApi
from modelscope import snapshot_download
# Authenticate with ModelScope using the token stored in the ``ms_app_key``
# environment variable. This must happen before the snapshot download below.
HubApi().login(os.getenv("ms_app_key"))
TEMP_DIR = "./flagged"
# Local directory holding the downloaded "monetjoe/EMusicGen" snapshot.
# NOTE(review): "./__pycache__" is also Python's bytecode-cache directory name;
# tools that purge __pycache__ folders would delete the weights too — confirm
# this location is intentional.
WEIGHTS_DIR = snapshot_download(
    "monetjoe/EMusicGen",
    cache_dir="./__pycache__",
)
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# ---- Model hyper-parameters ----------------------------------------------
# Patch length.
PATCH_LENGTH = 128
# Patch size.
PATCH_SIZE = 32
# Number of layers in the encoder.
PATCH_NUM_LAYERS = 9
# Number of layers in the decoder.
CHAR_NUM_LAYERS = 3
# Batch size for training patches; 0 means use the full context.
PATCH_SAMPLING_BATCH_SIZE = 0
# Whether to load weights from a checkpoint.
LOAD_FROM_CHECKPOINT = True
# Whether the encoder and decoder share weights.
SHARE_WEIGHTS = False
def download(
    filename: str,
    url: str,
    *,
    max_retries: int = 5,
    retry_delay: float = 10.0,
    timeout: float = 60.0,
) -> None:
    """Stream ``url`` into ``filename`` with a tqdm progress bar, retrying on failure.

    The payload is first written to ``<filename>.part``. It is moved onto
    ``filename`` with ``os.replace`` only after the transfer completes, so a
    failed or interrupted download never leaves a truncated file under the
    final name.

    Args:
        filename: Destination path. An existing file is overwritten.
        url: URL fetched with an HTTP GET.
        max_retries: Total number of attempts. Values below 1 are treated as
            1. The earlier version retried by recursing without limit, which
            could only end in ``RecursionError``.
        retry_delay: Seconds to sleep between attempts.
        timeout: Connect/read timeout in seconds passed to ``requests.get``.
            Without it, a stalled server could block forever.

    Raises:
        Exception: Whatever the final attempt raised once every attempt has
            failed, e.g. ``requests.HTTPError`` for a non-2xx response or a
            ``requests.RequestException`` for network errors.
    """
    attempts = max(1, max_retries)
    part_path = f"{filename}.part"
    # Large chunks keep the number of write() / bar.update() calls low.
    chunk_size = 64 * 1024
    for attempt in range(1, attempts + 1):
        try:
            # The ``with`` releases the pooled connection even on error.
            with requests.get(url, stream=True, timeout=timeout) as response:
                # Without this check, an HTML error page (404/5xx) would be
                # saved as if it were the requested file.
                response.raise_for_status()
                total_size = int(response.headers.get("content-length", 0))
                with open(part_path, "wb") as file, tqdm(
                    desc=f"Downloading {filename} from {url}...",
                    total=total_size,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for data in response.iter_content(chunk_size=chunk_size):
                        bar.update(file.write(data))
            # Atomic rename: filename is either absent/old or fully written.
            os.replace(part_path, filename)
            return
        except Exception as e:
            print(f"Error: {e}")
            if attempt == attempts:
                # Best-effort cleanup of the partial file; a cleanup failure
                # must not mask the real download error re-raised below.
                try:
                    os.remove(part_path)
                except OSError:
                    pass
                raise
            time.sleep(retry_delay)
if sys.platform.startswith("linux"):
    # On Linux, MuseScore is fetched as an AppImage and unpacked with
    # ``--appimage-extract`` into ``squashfs-root``. MSCORE then points at the
    # extracted ``AppRun`` launcher.
    apkname = "MuseScore.AppImage"
    extra_dir = "squashfs-root"
    # Download and extract only on first use. The AppImage is needed only for
    # extraction, so re-fetching it on every import was wasted bandwidth/time.
    if not os.path.exists(extra_dir):
        download(
            filename=apkname,
            url="https://www.modelscope.cn/studio/MuGeminorum/piano_transcription/resolve/master/MuseScore.AppImage",
        )
        # Equivalent of ``chmod +x``: the AppImage must be executable to run.
        os.chmod(apkname, os.stat(apkname).st_mode | 0o111)
        # check=True: fail loudly instead of leaving MSCORE pointing at a
        # binary that was never extracted.
        subprocess.run([f"./{apkname}", "--appimage-extract"], check=True)
    # NOTE(review): relative path, so it is only valid while the CWD is the
    # directory the AppImage was extracted into.
    MSCORE = f"./{extra_dir}/AppRun"
    # Select Qt's "offscreen" platform plugin so MuseScore runs without a display.
    os.environ["QT_QPA_PLATFORM"] = "offscreen"
else:
    # On other platforms, the MuseScore executable path comes from the
    # ``mscore`` environment variable (None when unset).
    MSCORE = os.getenv("mscore")