# Page header (not part of the program): Spaces — Running on CPU Upgrade
import os
import sys
import traceback
import multiprocessing

# Torch device used when a non-"cpu" device is requested on an install without CUDA.
# It is reassigned in the __main__ block: to a DirectML device, or to the CPU device
# if torch_directml is unavailable.
torch_dml_device = None
if __name__ == '__main__':
    server = None
    # Lets multiprocessing work from a frozen (PyInstaller) executable; see multiprocessing.freeze_support
    multiprocessing.freeze_support()

    # Production (bundled) builds are detected by xVASynth.exe being in the working directory
    PROD = 'xVASynth.exe' in os.listdir(".")
    # Root folder of the app's resources (javascript/, python/, ...) in the prod vs dev layout
    APP_ROOT = "./resources/app" if PROD else "."

    # Saves me having to do backend re-compilations for every little UI hotfix:
    # the version is the text between '"v' and the next '"' on line 2 of script.js
    with open(f'{APP_ROOT}/javascript/script.js', encoding="utf8") as f:
        lines = f.read().split("\n")
    APP_VERSION = lines[1].split('"v')[1].split('"')[0]

    # Imports and logger setup
    # ========================
    try:
        import python.pyinstaller_imports
        import numpy

        import logging
        from logging.handlers import RotatingFileHandler
        import json
        from http.server import BaseHTTPRequestHandler, HTTPServer
        from socketserver import ThreadingMixIn
        from python.audio_post import run_audio_post, prepare_input_audio, mp_ffmpeg_output, normalize_audio, start_microphone_recording, move_recorded_file
        import ffmpeg
    except Exception:
        # Best effort: record the failure; code below will fail on whichever names are missing
        print(traceback.format_exc())
        with open("./DEBUG_err_imports.txt", "w+") as f:
            f.write(traceback.format_exc())

    # Pyinstaller hack
    # ================
    # torch.jit.script / script_method are replaced with pass-throughs before torch is used,
    # so anything decorated with them is returned unchanged instead of being TorchScript-compiled.
    # NOTE(review): presumably because TorchScript needs Python source files, which a frozen build lacks — confirm
    try:
        def script_method(fn, _rcb=None):
            return fn

        def script(obj, optimize=True, _frames_up=0, _rcb=None):
            return obj

        import torch.jit
        torch.jit.script_method = script_method
        torch.jit.script = script
        import torch
        import tqdm
        import regex
    except Exception:
        with open("./DEBUG_err_import_torch.txt", "w+") as f:
            f.write(traceback.format_exc())
    # ================
    CPU_ONLY = not torch.cuda.is_available()

    try:
        logger = logging.getLogger('serverLog')
        logger.setLevel(logging.DEBUG)
        server_log_path = f'{os.path.dirname(os.path.realpath(__file__))}/{"../../../" if PROD else ""}/server.log'
        # File log gets everything (DEBUG+), rotated at 2 MB keeping 5 backups; the console only gets errors
        fh = RotatingFileHandler(server_log_path, maxBytes=2*1024*1024, backupCount=5)
        fh.setLevel(logging.DEBUG)
        ch = logging.StreamHandler()
        ch.setLevel(logging.ERROR)
        formatter = logging.Formatter('%(asctime)s - %(message)s')
        fh.setFormatter(formatter)
        ch.setFormatter(formatter)
        logger.addHandler(fh)
        logger.addHandler(ch)
        logger.info(f'New session. Version: {APP_VERSION}. Installation: {"CPU" if CPU_ONLY else "CPU+GPU"} | Prod: {PROD} | Log path: {server_log_path}')

        # logger.log is monkey-patched into a one-argument info logger: with a prefix set it
        # prepends "[prefix]: " to the message, otherwise it is the original logger.info.
        # NOTE: this shadows the standard Logger.log(level, msg) signature on this logger.
        logger.orig_info = logger.info

        def prefixed_log(msg):
            logger.info(f'{logger.logging_prefix}{msg}')

        def set_logger_prefix(prefix=""):
            if prefix:
                logger.logging_prefix = f'[{prefix}]: '
                logger.log = prefixed_log
            else:
                logger.log = logger.orig_info

        logger.set_logger_prefix = set_logger_prefix
        logger.set_logger_prefix("")
    except Exception:
        with open("./DEBUG_err_logger.txt", "w+") as f:
            f.write(traceback.format_exc())
        try:
            logger.info(traceback.format_exc())
        except Exception:
            pass

    if CPU_ONLY:
        try:
            import torch_directml
            torch_dml_device = torch_directml.device()
            logger.info("Successfully got the torch DirectML device")
        except Exception:
            # I've implemented support for DirectML, but at the time of writing (08/04/2023, v0.1.13.1.dev230301), it's hella broken...
            # Not a single model can successfully .forward() when switching to DirectML device from cpu. I'm leaving in the code however,
            # as I'd still like to add support for it once things are more stable. This try/catch should run ok when it's installed
            torch_dml_device = torch.device("cpu")
            logger.exception("Failed to get torch DirectML; falling back to cpu device")
    # ========================

    try:
        from python.plugins_manager import PluginManager
        plugin_manager = PluginManager(APP_VERSION, PROD, CPU_ONLY, logger)
        active_plugins = plugin_manager.get_active_plugins_count()
        logger.info(f'Plugin manager loaded. {active_plugins} active plugins.')
    except Exception:
        logger.info("Plugin manager FAILED.")
        logger.info(traceback.format_exc())

    # NOTE(review): if the PluginManager could not be created above, this line raises NameError
    plugin_manager.run_plugins(plist=plugin_manager.plugins["start"]["pre"], event="pre start", data=None)

    # ======================== Models manager
    # Per-game voice model folders, set by /setAvailableVoices or /checkReady (module-global)
    modelsPaths = {}
    try:
        from python.models_manager import ModelsManager
        models_manager = ModelsManager(logger, PROD, device="cpu")
    except Exception:
        logger.info("Models manager failed to initialize")
        logger.info(traceback.format_exc())
    # ========================

    print("Models ready")
    logger.info("Models ready")

    # Bundled checkpoints
    HIFIGAN_CKPT = f'{APP_ROOT}/python/hifigan/hifi.pt'
    # Speaker-representation model, used by speech-to-speech and the style-embedding endpoints
    SPEAKER_REP_CKPT = f'{APP_ROOT}/python/xvapitch/speaker_rep/speaker_rep.pt'

    def normalize_model_type(model_type):
        """Turn a UI model type name into its models_manager key (lowercase, '.'->'_', no spaces)."""
        return model_type.lower().replace(".", "_").replace(" ", "")

    def resolve_device(device_label):
        """Map a UI device label to a torch device.

        "cpu" maps to the CPU; any other label maps to the DirectML/CPU fallback device on
        installs without CUDA, and to cuda:0 otherwise.
        """
        if device_label == "cpu":
            return torch.device("cpu")
        return torch_dml_device if CPU_ONLY else torch.device("cuda:0")

    # Server
    class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
        """HTTPServer that handles each request in a separate thread."""
        # Prevent issues with socket reuse. This has to be a class attribute: socketserver
        # reads it while binding inside __init__, so assigning it on an instance after
        # construction has no effect. (HTTPServer already defaults it to true.)
        allow_reuse_address = True

    class Handler(BaseHTTPRequestHandler):
        """Request handler for the local xVASynth backend API.

        GET echoes a debug string; all real work is dispatched by path in do_POST.
        """

        def _set_response(self):
            """Send a 200 status line and a text/html Content-Type, then end the headers."""
            self.send_response(200)
            self.send_header("Content-Type", "text/html")
            self.end_headers()

        def do_GET(self):
            """Reply with a debug string naming the requested path."""
            returnString = "[DEBUG] Get request for {}".format(self.path).encode("utf-8")
            logger.info(returnString)
            self._set_response()
            self.wfile.write(returnString)

        def do_POST(self):
            """Dispatch a POST request on self.path.

            The body is parsed as JSON (an empty or missing body gives ``{}``). Each endpoint
            sets ``req_response``, which is sent back as a UTF-8 text/html 200 response;
            endpoints that don't set it reply "POST request for <path>". On any exception the
            traceback and request data are written to ./DEBUG_request.txt and logged, and no
            response is sent.
            """
            global modelsPaths
            post_data = ""
            try:
                # A missing Content-Length header is treated as an empty body
                content_length = int(self.headers.get('Content-Length', 0))
                post_data = json.loads(self.rfile.read(content_length).decode('utf-8')) if content_length else {}
                req_response = "POST request for {}".format(self.path)

                print("POST")
                print(self.path)

                # For headless mode
                if self.path == "/setAvailableVoices":
                    modelsPaths = json.loads(post_data["modelsPaths"])

                if self.path == "/getAvailableVoices":
                    models = {}
                    for gameId in modelsPaths.keys():
                        models[gameId] = []
                        modelJSONs = sorted(os.listdir(modelsPaths[gameId]))
                        for fname in modelJSONs:
                            if fname.endswith(".json"):
                                # Read as UTF-8 explicitly instead of the platform's locale encoding
                                with open(f'{modelsPaths[gameId]}/{fname}', "r", encoding="utf8") as f:
                                    metadata = json.loads(f.read())
                                game = metadata["games"][0]
                                models[gameId].append({
                                    "modelType": metadata["modelType"],
                                    "author": metadata.get("author", ""),
                                    "emb_size": metadata.get("emb_size", 1),
                                    "voiceId": game["voiceId"],
                                    "voiceName": game["voiceName"],
                                    "gender": game.get("gender", "other"),
                                    "emb_i": game.get("emb_i", 0),
                                })
                    req_response = json.dumps(models)

                if self.path == "/setVocoder":
                    logger.info("POST {}".format(self.path))
                    logger.info(post_data)
                    vocoder = post_data["vocoder"]
                    modelPath = post_data["modelPath"]
                    hifi_gan = "waveglow" not in vocoder

                    if vocoder == "qnd":
                        req_response = models_manager.load_model("hifigan", HIFIGAN_CKPT)
                    elif not hifi_gan:
                        req_response = models_manager.load_model(vocoder, modelPath)

                    req_response = "" if req_response is None else req_response

                if self.path == "/stopServer":
                    logger.info("POST {}".format(self.path))
                    logger.info("STOPPING SERVER")
                    # Requests run in worker threads (ThreadingMixIn), so shutdown() can wait for
                    # serve_forever() on the main thread; sys.exit() then raises SystemExit, which
                    # the `except Exception` below does not catch.
                    server.shutdown()
                    sys.exit()

                if self.path == "/normalizeAudio":
                    input_path = post_data["input_path"]
                    output_path = post_data["output_path"]
                    req_response = normalize_audio(input_path, output_path)

                if self.path == "/customEvent":
                    logger.info("POST {}".format(self.path))
                    plugin_manager.run_plugins(plist=plugin_manager.plugins["custom-event"], event="custom-event", data=post_data)

                if self.path == "/setDevice":
                    logger.info("POST {}".format(self.path))
                    logger.info(post_data)
                    if post_data["device"] == "cpu":
                        logger.info("Setting torch device to CPU")
                        device = torch.device("cpu")
                    elif CPU_ONLY:
                        logger.info("Setting torch device to DirectML")
                        device = torch_dml_device
                    else:
                        logger.info("Setting torch device to CUDA")
                        device = torch.device("cuda:0")
                    models_manager.set_device(device)

                if self.path == "/loadModel":
                    logger.info("POST {}".format(self.path))
                    logger.info(post_data)
                    ckpt = post_data["model"]
                    modelType = normalize_model_type(post_data["modelType"])
                    instance_index = post_data.get("instance_index", 0)
                    post_data["pluginsContext"] = json.loads(post_data["pluginsContext"])
                    n_speakers = post_data.get("model_speakers")
                    base_lang = post_data.get("base_lang")

                    plugin_manager.run_plugins(plist=plugin_manager.plugins["load-model"]["pre"], event="pre load-model", data=post_data)
                    models_manager.load_model(modelType, ckpt+".pt", instance_index=instance_index, n_speakers=n_speakers, base_lang=base_lang)
                    plugin_manager.run_plugins(plist=plugin_manager.plugins["load-model"]["post"], event="post load-model", data=post_data)

                    if modelType == "fastpitch1_1":
                        models_manager.models_bank["fastpitch1_1"][instance_index].init_arpabet_dicts()

                if self.path == "/getG2P":
                    text = post_data["text"]
                    base_lang = post_data["base_lang"]
                    # Use the same normalized (lowercase) model key as every other xVAPitch call
                    model = models_manager.models("xvapitch", instance_index=0)
                    req_response = model.getG2P(text, base_lang)

                if self.path == "/synthesizeSimple":
                    logger.info("POST {}".format(self.path))
                    text = post_data["sequence"]
                    instance_index = post_data.get("instance_index", 0)
                    out_path = post_data["outfile"]
                    base_lang = post_data.get("base_lang")
                    base_emb = post_data.get("base_emb")
                    useCleanup = post_data.get("useCleanup")

                    model = models_manager.models("xvapitch", instance_index=instance_index)
                    req_response = model.infer(plugin_manager, text, out_path, vocoder=None,
                                               speaker_i=None, editor_data=None, pace=None, old_sequence=None,
                                               globalAmplitudeModifier=None, base_lang=base_lang, base_emb=base_emb,
                                               useSR=False, useCleanup=useCleanup)

                if self.path == "/synthesize":
                    logger.info("POST {}".format(self.path))
                    post_data["pluginsContext"] = json.loads(post_data["pluginsContext"])
                    instance_index = post_data.get("instance_index", 0)

                    # Handle the case where the vocoder remains selected on app start-up, with auto-HiFi turned off, but no setVocoder call is made before synth
                    continue_synth = True
                    if "waveglow" in post_data["vocoder"]:
                        waveglowPath = post_data["waveglowPath"]
                        req_response = models_manager.load_model(post_data["vocoder"], waveglowPath, instance_index=instance_index)
                        if req_response == "ENOENT":
                            continue_synth = False

                    # Conditional (not .get) so device_label is only read when no device was sent
                    device = post_data["device"] if "device" in post_data else models_manager.device_label
                    models_manager.set_device(resolve_device(device), instance_index=instance_index)

                    if continue_synth:
                        plugin_manager.set_context(post_data["pluginsContext"])
                        plugin_manager.run_plugins(plist=plugin_manager.plugins["synth-line"]["pre"], event="pre synth-line", data=post_data)

                        # Read after the pre synth-line plugins have been handed post_data
                        modelType = post_data["modelType"]
                        text = post_data["sequence"]
                        pace = float(post_data["pace"])
                        out_path = post_data["outfile"]
                        base_lang = post_data.get("base_lang")
                        base_emb = post_data.get("base_emb")
                        pitch = post_data.get("pitch")
                        energy = post_data.get("energy")
                        emAngry = post_data.get("emAngry")
                        emHappy = post_data.get("emHappy")
                        emSad = post_data.get("emSad")
                        emSurprise = post_data.get("emSurprise")
                        editorStyles = post_data.get("editorStyles")
                        duration = post_data.get("duration")
                        speaker_i = post_data.get("speaker_i")
                        useSR = post_data.get("useSR")
                        useCleanup = post_data.get("useCleanup")
                        vocoder = post_data["vocoder"]
                        globalAmplitudeModifier = float(post_data["globalAmplitudeModifier"]) if "globalAmplitudeModifier" in post_data else None
                        editor_data = [pitch, duration, energy, emAngry, emHappy, emSad, emSurprise, editorStyles]
                        old_sequence = post_data.get("old_sequence")

                        model = models_manager.models(normalize_model_type(modelType), instance_index=instance_index)
                        req_response = model.infer(plugin_manager, text, out_path, vocoder=vocoder,
                                                   speaker_i=speaker_i, editor_data=editor_data, pace=pace, old_sequence=old_sequence,
                                                   globalAmplitudeModifier=globalAmplitudeModifier, base_lang=base_lang, base_emb=base_emb,
                                                   useSR=useSR, useCleanup=useCleanup)

                        plugin_manager.run_plugins(plist=plugin_manager.plugins["synth-line"]["post"], event="post synth-line", data=post_data)

                if self.path == "/synthesize_batch":
                    post_data["pluginsContext"] = json.loads(post_data["pluginsContext"])
                    plugin_manager.set_context(post_data["pluginsContext"])
                    plugin_manager.run_plugins(plist=plugin_manager.plugins["batch-synth-line"]["pre"], event="pre batch-synth-line", data=post_data)

                    modelType = post_data["modelType"]
                    linesBatch = post_data["linesBatch"]
                    speaker_i = post_data["speaker_i"]
                    vocoder = post_data["vocoder"]
                    outputJSON = post_data["outputJSON"]
                    useSR = post_data["useSR"]
                    useCleanup = post_data["useCleanup"]

                    with torch.no_grad():
                        try:
                            model = models_manager.models(normalize_model_type(modelType))
                            req_response = model.infer_batch(plugin_manager, linesBatch, outputJSON=outputJSON, vocoder=vocoder, speaker_i=speaker_i, useSR=useSR, useCleanup=useCleanup)
                        except Exception:
                            # CUDA OOM is reported with a short sentinel; any other failure returns the traceback
                            err = traceback.format_exc()
                            req_response = "CUDA OOM" if "CUDA out of memory" in err else err
                            logger.info(err)

                    post_data["req_response"] = req_response
                    plugin_manager.run_plugins(plist=plugin_manager.plugins["batch-synth-line"]["post"], event="post batch-synth-line", data=post_data)

                if self.path == "/runSpeechToSpeech":
                    logger.info("POST {}".format(self.path))
                    input_path = post_data["input_path"]
                    style_emb = post_data["style_emb"]
                    options = post_data["options"]
                    audio_out_path = post_data["audio_out_path"]
                    useSR = post_data["useSR"]
                    useCleanup = post_data["useCleanup"]
                    vc_strength = post_data["vc_strength"]
                    removeNoise = post_data["removeNoise"]
                    removeNoiseStrength = post_data["removeNoiseStrength"]
                    final_path = prepare_input_audio(PROD, logger, input_path, removeNoise, removeNoiseStrength)

                    models_manager.init_model("speaker_rep")
                    models_manager.load_model("speaker_rep", SPEAKER_REP_CKPT)

                    # Conversion writes a temporary *_tempS2S.wav; on success (req_response == "")
                    # the /outputAudio block below post-processes it into audio_out_path
                    temp_out_path = audio_out_path.replace(".wav", "_tempS2S.wav")
                    try:
                        out = models_manager.models("xvapitch").run_speech_to_speech(final_path, temp_out_path, style_emb, models_manager, plugin_manager, vc_strength=vc_strength, useSR=useSR, useCleanup=useCleanup)
                        if out == "TOO_SHORT":
                            req_response = "TOO_SHORT"
                        else:
                            req_response = ""
                            # For use by /outputAudio
                            post_data["input_path"] = temp_out_path
                            post_data["output_path"] = audio_out_path
                    except (ValueError, RuntimeError):
                        req_response = traceback.format_exc()
                        logger.info(req_response)
                    except Exception as e:
                        req_response = traceback.format_exc()
                        logger.info(req_response)
                        logger.info(repr(e))

                if self.path == "/batchOutputAudio":
                    input_paths = post_data["input_paths"]
                    output_paths = post_data["output_paths"]
                    processes = post_data["processes"]
                    options = json.loads(post_data["options"])

                    # For plugins
                    # NOTE(review): original indentation was lost; the key assignments are
                    # reconstructed as belonging to the `if "extraInfo"` branch — confirm
                    extraInfo = {}
                    if "extraInfo" in post_data:
                        extraInfo = json.loads(post_data["extraInfo"])
                        extraInfo["pluginsContext"] = json.loads(post_data["pluginsContext"])
                        extraInfo["audio_options"] = options
                        extraInfo["input_paths"] = input_paths
                        extraInfo["output_paths"] = output_paths
                        extraInfo["processes"] = processes
                        extraInfo["ffmpeg"] = ffmpeg

                    plugin_manager.run_plugins(plist=plugin_manager.plugins["mp-output-audio"]["pre"], event="pre mp-output-audio", data=extraInfo)
                    req_response = mp_ffmpeg_output(PROD, logger, processes, input_paths, output_paths, options)
                    plugin_manager.run_plugins(plist=plugin_manager.plugins["mp-output-audio"]["post"], event="post mp-output-audio", data=extraInfo)

                # Also runs right after a successful /runSpeechToSpeech, to post-process its output
                if self.path == "/outputAudio" or (self.path == "/runSpeechToSpeech" and req_response == ""):
                    isBatchMode = post_data["isBatchMode"]
                    if not isBatchMode:
                        logger.info("POST /outputAudio")
                    input_path = post_data["input_path"]
                    output_path = post_data["output_path"]
                    options = json.loads(post_data["options"])

                    # For plugins
                    # NOTE(review): original indentation was lost; the key assignments are
                    # reconstructed as belonging to the `if "extraInfo"` branch — confirm
                    extraInfo = {}
                    if "extraInfo" in post_data:
                        extraInfo = json.loads(post_data["extraInfo"])
                        extraInfo["pluginsContext"] = json.loads(post_data["pluginsContext"])
                        extraInfo["audio_options"] = options
                        extraInfo["input_path"] = input_path
                        extraInfo["output_path"] = output_path
                        extraInfo["ffmpeg"] = ffmpeg

                    plugin_manager.run_plugins(plist=plugin_manager.plugins["output-audio"]["pre"], event="pre output-audio", data=extraInfo)
                    input_path = post_data["input_path"]
                    output_path = post_data["output_path"]
                    req_response = run_audio_post(PROD, None if isBatchMode else logger, input_path, output_path, options)
                    plugin_manager.run_plugins(plist=plugin_manager.plugins["output-audio"]["post"], event="post output-audio", data=extraInfo)

                if self.path == "/refreshPlugins":
                    logger.info("POST {}".format(self.path))
                    status = plugin_manager.refresh_active_plugins()
                    logger.info("status")
                    logger.info(status)
                    req_response = ",".join(status)

                if self.path == "/getWavV3StyleEmb":
                    logger.info("POST {}".format(self.path))
                    wav_path = post_data["wav_path"]
                    models_manager.init_model("speaker_rep")
                    load_resp = models_manager.load_model("speaker_rep", SPEAKER_REP_CKPT)
                    if load_resp == "ENOENT":
                        req_response = "ENOENT"
                    else:
                        style_emb = models_manager.models("speaker_rep").compute_embedding(wav_path).squeeze().cpu().detach().numpy()
                        req_response = ",".join([str(v) for v in style_emb])

                if self.path == "/computeEmbsAndDimReduction":
                    logger.info("POST {}".format(self.path))
                    models_manager.init_model("speaker_rep")
                    # NOTE(review): unlike /getWavV3StyleEmb, an "ENOENT" load result is not checked here
                    models_manager.load_model("speaker_rep", SPEAKER_REP_CKPT)
                    embs = models_manager.models("speaker_rep").reduce_data_dimension(post_data["mappings"], post_data["includeAllVoices"], post_data["onlyInstalled"], post_data["algorithm"])
                    req_response = embs

                if self.path == "/checkReady":
                    modelsPaths = json.loads(post_data["modelsPaths"])
                    models_manager.set_device(resolve_device(post_data["device"]))
                    req_response = "ready"

                if self.path == "/updateARPABet":
                    if "fastpitch1_1" in models_manager.models_bank:
                        # models_bank["fastpitch1_1"] is indexed by instance_index (see /loadModel), so
                        # refresh each loaded instance rather than calling the method on the container.
                        # NOTE(review): assumes each instance provides refresh_arpabet_dicts() — confirm
                        fastpitch_bank = models_manager.models_bank["fastpitch1_1"]
                        instances = fastpitch_bank.values() if isinstance(fastpitch_bank, dict) else fastpitch_bank
                        for instance in instances:
                            instance.refresh_arpabet_dicts()

                if self.path == "/start_microphone_recording":
                    start_microphone_recording(logger, models_manager, APP_ROOT)
                    req_response = ""

                if self.path == "/move_recorded_file":
                    file_path = post_data["file_path"]
                    move_recorded_file(PROD, logger, models_manager, APP_ROOT, file_path)

                self._set_response()
                self.wfile.write(req_response.encode("utf-8"))

            except Exception as e:
                # NOTE(review): no HTTP response is sent on failure; the client only sees the connection end
                with open("./DEBUG_request.txt", "w+") as f:
                    f.write(traceback.format_exc())
                    f.write(str(post_data))
                logger.info("Post Error:\n {}".format(repr(e)))
                print(traceback.format_exc())
                logger.info(traceback.format_exc())

    try:
        server = ThreadedHTTPServer(("",8008), Handler)
    except Exception:
        with open("./DEBUG_server_error.txt", "w+") as f:
            f.write(traceback.format_exc())
        logger.info(traceback.format_exc())

    try:
        plugin_manager.run_plugins(plist=plugin_manager.plugins["start"]["post"], event="post start", data=None)
        print("Server ready")
        logger.info("Server ready")
        server.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        # server stays None if construction failed above
        if server is not None:
            server.server_close()