# Spaces: Running
r'''
Usage notes for the training server.

Start the container:
sudo docker run --gpus all --runtime=nvidia --rm \
    -v /home/ubuntu/dotdemo/third_party:/third_party \
    -v /home/ubuntu/dotdemo-dev:/dotdemo-dev \
    -v /home/ubuntu/dot-demo-assets/ml-logs:/logs \
    -v /home/ubuntu/dotdemo/train_server:/app \
    --network="host" \
    --shm-size 1G \
    -it fantasyfish677/rvc:v0 /bin/bash

Inside the container:
pip3 install flask_cors
python3 /app/server.py 2>&1 | tee /logs/train_server.log
export FLASK_APP=server
export FLASK_DEBUG=true
pip3 install gunicorn
gunicorn -b :8080 --timeout=600 server:app

Example requests:
curl -X GET http://3.16.130.199:8080/ping
curl -X POST http://3.16.130.199:8080/train \
    -H 'Content-Type: application/json' \
    -d '{"expName":"varun124","trainsetDir":"varun124"}'
curl -X GET http://3.16.130.199:8080/check \
    -H 'Content-Type: application/json' \
    -d '{"expName":"kanye-1"}'
'''
import json
import logging
import os
import shutil
import time
from logging import exception

from flask import Flask, request
from flask_cors import CORS, cross_origin

from server_utils import train_model
# Module setup: confirm the imports above loaded, then create the Flask app
# and attach flask-cors to it.
print("import successful!")
app = Flask("train server")
cors = CORS(app)
# NOTE(review): this key is set after CORS(app) was constructed; confirm
# flask-cors reads it at all (its allowed-headers setting appears to be
# CORS_ALLOW_HEADERS).
app.config.update(CORS_HEADERS='Content-Type')
def healthcheck():
    """Report that the server is alive.

    Returns:
        UTF-8 encoded JSON bytes: {"code": 200, "message": "responding"}.
    """
    # NOTE(review): no @app.route decorator is visible for this handler; the
    # usage notes at the top of this file call it as GET /ping -- confirm it
    # is registered.
    status = dict(code=200, message="responding")
    body = json.dumps(status)
    return bytes(body, 'utf-8')
def _is_plain_name(name):
    """Return True if *name* is a string usable as a single file/dir name."""
    # Rejects anything that could walk out of the directory it is joined
    # onto ("..", separators) or that the OS cannot accept (NUL byte).
    return (
        isinstance(name, str)
        and name not in ('', '.', '..')
        and '/' not in name
        and '\\' not in name
        and '\0' not in name
    )


def _is_inside(base, path):
    """Return True if *path* lexically normalizes to *base* or a path below it."""
    base = os.path.normpath(base)
    return os.path.commonpath([base, os.path.normpath(path)]) == base


def _remove_path(path):
    """Delete *path* like ``rm -rf``: a directory tree, a file or a symlink.

    Does nothing if *path* does not exist. Unlike the shell command, a
    failure raises an OSError instead of being silently ignored.
    """
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    elif os.path.lexists(path):
        os.remove(path)


def train():
    """Train an RVC model for the experiment named in the JSON request body.

    Expected JSON body:
        expName:     experiment name. Must be a single path component (not
                     empty, "." or "..", and without slash, backslash or NUL)
                     because it is used to build /logs/<expName>.log and
                     /third_party/RVC/logs/<expName>.
        trainsetDir: training-set directory, relative to /dotdemo-dev. It may
                     contain subdirectories but must not point outside
                     /dotdemo-dev.

    Once the request is validated and the training set exists, any previous
    /third_party/RVC/logs/<expName> entry is removed and train_model() is
    called with total_epoch=20.

    Returns:
        A (JSON string, HTTP status) tuple:
          200 - training finished; the message includes the elapsed seconds.
          400 - expName or trainsetDir is invalid.
          404 - the training-set directory does not exist.
          500 - the Content-Type is not application/json, or any step raised.
    """
    # NOTE(review): no @app.route decorator is visible for this handler; the
    # usage notes at the top of this file call it as POST /train -- confirm
    # it is registered.

    # request.mimetype drops parameters such as "; charset=utf-8" and is ""
    # when the header is absent, so a missing header no longer raises KeyError.
    if request.mimetype != 'application/json':
        # Not inside an except block, so log without a (nonexistent) traceback.
        logging.error("Header error")
        return json.dumps({"message": "Header error"}), 500
    try:
        content = request.get_json()
        exp_name = content['expName']
        # exp_name is joined onto a directory that gets deleted below, so a
        # value such as ".." must never reach the filesystem.
        if not _is_plain_name(exp_name):
            logging.error("Invalid expName")
            return json.dumps({"message": "Invalid expName"}), 400
        trainset_dir = os.path.join('/dotdemo-dev', content['trainsetDir'])
        # os.path.join discards the base for absolute paths and ".." can climb
        # out of it; only accept training sets under /dotdemo-dev.
        if not _is_inside('/dotdemo-dev', trainset_dir):
            logging.error("Invalid trainsetDir")
            return json.dumps({"message": "Invalid trainsetDir"}), 400
        # The log file goes inside the /logs directory, not next to it.
        log_path = os.path.join('/logs', '{}.log'.format(exp_name))
        if not os.path.exists(trainset_dir):
            logging.error("Training set doesn't exist")
            return json.dumps({"message": "Training set doesn't exist"}), 404
        # Clear the previous run only once the request is known to be usable,
        # so a bad request does not destroy existing logs. No shell is
        # involved, so exp_name cannot inject commands.
        _remove_path(os.path.join('/third_party/RVC/logs', exp_name))
        start_time = time.perf_counter()
        train_model(exp_name, trainset_dir, log_path, total_epoch=20)
        elapsed = time.perf_counter() - start_time
        return json.dumps({"message": "Training Completed in {} secs.".format(elapsed)}), 200
    except Exception as e:
        # Request boundary: log the traceback and report the failure.
        exception("Training process failed")
        return json.dumps({"message": "Training process failed due to {}".format(e)}), 500
def check():
    """Report whether trained weights exist for the experiment in the JSON body.

    Expected JSON body:
        expName: experiment name; looked up as
                 /third_party/RVC/weights/<expName>.pth.

    Returns:
        A (JSON string, HTTP status) tuple:
          200 - {"message": "Model found."} or {"message": "Model not found."}.
          400 - the body has no usable expName (missing, not a string, or not
                a single path component).
          500 - the Content-Type is not application/json.
    """
    # NOTE(review): no @app.route decorator is visible for this handler; the
    # usage notes at the top of this file call it as GET /check -- confirm it
    # is registered.

    # request.mimetype ignores "; charset=..." parameters and is "" when the
    # header is missing, instead of raising KeyError.
    if request.mimetype != 'application/json':
        # Not inside an except block, so log without a (nonexistent) traceback.
        logging.error("Header error")
        return json.dumps({"message": "Header error"}), 500
    # silent=True yields None for a malformed body instead of raising, so every
    # unusable payload gets the same JSON 400 below.
    content = request.get_json(silent=True)
    exp_name = content.get('expName') if isinstance(content, dict) else None
    # expName must name a single file inside the weights directory, so it
    # cannot be used to probe for *.pth files elsewhere on disk.
    if (not isinstance(exp_name, str)
            or exp_name in ('', '.', '..')
            or '/' in exp_name
            or '\\' in exp_name
            or '\0' in exp_name):
        logging.error("Invalid expName")
        return json.dumps({"message": "Invalid expName"}), 400
    weights_path = os.path.join('/third_party/RVC/weights', '{}.pth'.format(exp_name))
    message = "Model found." if os.path.exists(weights_path) else "Model not found."
    return json.dumps({"message": message}), 200
if __name__ == "__main__":
    # Flask's docs warn never to enable the interactive debugger on a server
    # others can reach: it can execute arbitrary Python. This binds to every
    # interface, so debug mode is opt-in via FLASK_DEBUG (exported in the
    # usage notes at the top of this file). The variable is parsed the way
    # Flask parses it: set and not "0"/"false"/"no" means enabled.
    _debug_env = os.environ.get("FLASK_DEBUG", "")
    _debug = bool(_debug_env) and _debug_env.lower() not in ("0", "false", "no")
    app.run(host="0.0.0.0", port=8080, debug=_debug)