Spaces:
Sleeping
Sleeping
mv batch option to argparse
Browse files- app.py +6 -2
- javascript/app.js +2 -1
app.py
CHANGED
|
@@ -20,7 +20,6 @@ from midi_model import MIDIModel, MIDIModelConfig
|
|
| 20 |
from midi_synthesizer import MidiSynthesizer
|
| 21 |
|
| 22 |
MAX_SEED = np.iinfo(np.int32).max
|
| 23 |
-
OUTPUT_BATCH_SIZE = 8
|
| 24 |
in_space = os.getenv("SYSTEM") == "spaces"
|
| 25 |
|
| 26 |
|
|
@@ -305,7 +304,10 @@ def load_javascript(dir="javascript"):
|
|
| 305 |
javascript = ""
|
| 306 |
for path in scripts_list:
|
| 307 |
with open(path, "r", encoding="utf8") as jsfile:
|
| 308 |
-
|
|
|
|
|
|
|
|
|
|
| 309 |
template_response_ori = gr.routes.templates.TemplateResponse
|
| 310 |
|
| 311 |
def template_response(*args, **kwargs):
|
|
@@ -344,8 +346,10 @@ if __name__ == "__main__":
|
|
| 344 |
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
| 345 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
| 346 |
parser.add_argument("--device", type=str, default="cuda", help="device to run model")
|
|
|
|
| 347 |
parser.add_argument("--max-gen", type=int, default=1024, help="max")
|
| 348 |
opt = parser.parse_args()
|
|
|
|
| 349 |
soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
| 350 |
thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
|
| 351 |
synthesizer = MidiSynthesizer(soundfont_path)
|
|
|
|
| 20 |
from midi_synthesizer import MidiSynthesizer
|
| 21 |
|
| 22 |
MAX_SEED = np.iinfo(np.int32).max
|
|
|
|
| 23 |
in_space = os.getenv("SYSTEM") == "spaces"
|
| 24 |
|
| 25 |
|
|
|
|
| 304 |
javascript = ""
|
| 305 |
for path in scripts_list:
|
| 306 |
with open(path, "r", encoding="utf8") as jsfile:
|
| 307 |
+
js_content = jsfile.read()
|
| 308 |
+
js_content = js_content.replace("const MIDI_OUTPUT_BATCH_SIZE=4;",
|
| 309 |
+
f"const MIDI_OUTPUT_BATCH_SIZE={OUTPUT_BATCH_SIZE};")
|
| 310 |
+
javascript += f"\n<!-- {path} --><script>{js_content}</script>"
|
| 311 |
template_response_ori = gr.routes.templates.TemplateResponse
|
| 312 |
|
| 313 |
def template_response(*args, **kwargs):
|
|
|
|
| 346 |
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
| 347 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
| 348 |
parser.add_argument("--device", type=str, default="cuda", help="device to run model")
|
| 349 |
+
parser.add_argument("--batch", type=int, default=8, help="batch size")
|
| 350 |
parser.add_argument("--max-gen", type=int, default=1024, help="max")
|
| 351 |
opt = parser.parse_args()
|
| 352 |
+
OUTPUT_BATCH_SIZE = opt.batch
|
| 353 |
soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
| 354 |
thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
|
| 355 |
synthesizer = MidiSynthesizer(soundfont_path)
|
javascript/app.js
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
-
const MIDI_OUTPUT_BATCH_SIZE=
|
|
|
|
| 2 |
|
| 3 |
/**
|
| 4 |
* 自动绕过 shadowRoot 的 querySelector
|
|
|
|
| 1 |
+
const MIDI_OUTPUT_BATCH_SIZE=4;
|
| 2 |
+
//Do not change MIDI_OUTPUT_BATCH_SIZE. It will be automatically replaced.
|
| 3 |
|
| 4 |
/**
|
| 5 |
* 自动绕过 shadowRoot 的 querySelector
|