# dotdemo / train_server / server.py
# Last commit: 239a35e by fantasyfish ("update two server.py files")
'''
sudo docker run --gpus all --runtime=nvidia --rm \
-v /home/ubuntu/dotdemo/third_party:/third_party \
-v /home/ubuntu/dotdemo-dev:/dotdemo-dev \
-v /home/ubuntu/dot-demo-assets/ml-logs:/logs \
-v /home/ubuntu/dotdemo/train_server:/app \
--network="host" \
--shm-size 1G \
-it fantasyfish677/rvc:v0 /bin/bash
pip3 install flask_cors
python3 /app/server.py 2>&1 | tee /logs/train_server.log
export FLASK_APP=server
export FLASK_DEBUG=true
pip3 install gunicorn
gunicorn -b :8080 --timeout=600 server:app
curl -X GET http://3.16.130.199:8080/ping
curl -X POST http://3.16.130.199:8080/train \
-H 'Content-Type: application/json' \
-d '{"expName":"varun124","trainsetDir":"varun124"}'
curl -X GET http://3.16.130.199:8080/check \
-H 'Content-Type: application/json' \
-d '{"expName":"kanye-1"}'
'''
import json
import logging
import os
import shutil
import time
from logging import exception

from flask import Flask, request
from flask_cors import CORS, cross_origin

from server_utils import train_model
# Printed once at import time, after every import above has succeeded.
print("import successful!")

# The WSGI application object; the gunicorn command in the module docstring
# serves it as `server:app`.
app = Flask("train server")

# Attach Flask-CORS to the whole app (the routes additionally use @cross_origin()).
cors = CORS(app)
app.config.update(CORS_HEADERS="Content-Type")
@app.route("/ping", methods=['GET', 'POST'])
@cross_origin()
def healthcheck():
    """Liveness probe that always answers with the same JSON document.

    Returns:
        UTF-8 encoded bytes of ``{"code": 200, "message": "responding"}``.
        Flask sends them with its default 200 status.
    """
    status_payload = {"code": 200, "message": "responding"}
    serialized = json.dumps(status_payload)
    return serialized.encode('utf-8')
@app.route("/train", methods=['GET', 'POST'])
@cross_origin()
def train():
    """Train a model for one experiment and block until training finishes.

    Expects a JSON object body ``{"expName": str, "trainsetDir": str}``.
    ``expName`` must be a single path component. ``trainsetDir`` is resolved
    under ``/dotdemo-dev`` and must stay inside it. Once the training set is
    confirmed to exist, any previous ``/third_party/RVC/logs/<expName>`` entry
    is deleted and ``train_model`` runs for 20 epochs, logging to
    ``/logs/<expName>.log``.

    Returns:
        A ``(json_string, status)`` tuple:

        * 200 with the elapsed training time on success.
        * 400 if the body is not a JSON object, or if ``expName`` or
          ``trainsetDir`` is missing or unsafe.
        * 404 if the training-set directory does not exist.
        * 500 for a non-JSON request, or if any exception is raised while
          processing (including malformed JSON and failed cleanup).
    """
    # request.is_json accepts parameters such as "; charset=utf-8" and does not
    # raise when the header is absent.
    # NOTE(review): the 500 status is kept for existing clients; 415 would be
    # the conventional code for an unsupported media type.
    if not request.is_json:
        logging.error("Header error")
        return json.dumps({"message":"Header error"}), 500
    try:
        content = request.get_json()
        if not isinstance(content, dict):
            logging.error("Request body is not a JSON object")
            return json.dumps({"message": "Request body must be a JSON object"}), 400

        exp_name = content.get('expName')
        # expName is used below to build paths that are deleted and written,
        # so it must be one plain file-name component (no separators or "..").
        if (not isinstance(exp_name, str) or exp_name in ("", ".", "..")
                or "\x00" in exp_name or os.path.basename(exp_name) != exp_name):
            logging.error("Invalid expName: %r", exp_name)
            return json.dumps({"message": "Invalid expName"}), 400

        trainset_root = '/dotdemo-dev'
        trainset_rel = content.get('trainsetDir')
        trainset_dir = None
        if isinstance(trainset_rel, str) and "\x00" not in trainset_rel:
            # normpath collapses "..". An absolute trainsetDir would replace the
            # root inside join(). The containment check rejects both escapes.
            candidate = os.path.normpath(os.path.join(trainset_root, trainset_rel))
            if os.path.commonpath([trainset_root, candidate]) == trainset_root:
                trainset_dir = candidate
        if trainset_dir is None:
            logging.error("Invalid trainsetDir: %r", trainset_rel)
            return json.dumps({"message": "Invalid trainsetDir"}), 400

        if not os.path.exists(trainset_dir):
            logging.error("Training set doesn't exist")
            return json.dumps({"message":"Training set doesn't exist"}), 404

        # Per-experiment training log inside the /logs directory.
        log_path = os.path.join('/logs', '{}.log'.format(exp_name))

        # Remove stale state from an earlier run with the same name. This is the
        # equivalent of `rm -rf` without invoking a shell, and it happens only
        # after validation, so a rejected request deletes nothing. Unlike
        # `rm -rf`, a failed removal raises and is reported as a 500.
        exp_log_dir = os.path.join('/third_party/RVC/logs', exp_name)
        if os.path.isdir(exp_log_dir) and not os.path.islink(exp_log_dir):
            shutil.rmtree(exp_log_dir)
        elif os.path.lexists(exp_log_dir):
            os.remove(exp_log_dir)

        # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
        start_time = time.monotonic()
        train_model(exp_name, trainset_dir, log_path, total_epoch=20)
        elapsed = time.monotonic() - start_time
        return json.dumps({"message": "Training Completed in {} secs.".format(elapsed)}), 200
    except Exception as e:
        # Top-level request boundary: log the traceback and report the failure.
        logging.exception("Training process failed")
        return json.dumps({"message":"Training process failed due to {}".format(e)}), 500
@app.route("/check", methods=['GET', 'POST'])
@cross_origin()
def check():
    """Report whether trained weights exist for an experiment.

    Expects a JSON object body ``{"expName": str}`` and tests for the file
    ``/third_party/RVC/weights/<expName>.pth``.

    Returns:
        A ``(json_string, status)`` tuple:

        * 200 with "Model found." or "Model not found."
        * 400 if ``expName`` is missing or is not a plain file-name component.
        * 500 if the request is not JSON.
    """
    # request.is_json tolerates "; charset=..." and a missing header.
    # NOTE(review): the 500 status is kept for compatibility; 415 is conventional.
    if not request.is_json:
        logging.error("Header error")
        return json.dumps({"message":"Header error"}), 500
    content = request.get_json()
    exp_name = content.get('expName') if isinstance(content, dict) else None
    # Reject separators and "..", so the lookup cannot probe paths outside
    # the weights directory.
    if (not isinstance(exp_name, str) or exp_name in ("", ".", "..")
            or "\x00" in exp_name or os.path.basename(exp_name) != exp_name):
        logging.error("Invalid expName: %r", exp_name)
        return json.dumps({"message": "Invalid expName"}), 400
    weights_path = os.path.join('/third_party/RVC/weights', '{}.pth'.format(exp_name))
    if os.path.exists(weights_path):
        return json.dumps({"message": "Model found."}), 200
    return json.dumps({"message": "Model not found."}), 200
if __name__ == "__main__":
    # Werkzeug's interactive debugger lets anyone who can reach the server run
    # code, so it must not be on unconditionally while bound to 0.0.0.0.
    # Debug mode is opt-in via FLASK_DEBUG. It follows Flask's own rule:
    # unset, "", "0", "false" or "no" (case-insensitive) all mean off.
    flask_debug = os.environ.get("FLASK_DEBUG", "")
    debug_enabled = flask_debug.lower() not in ("", "0", "false", "no")
    app.run(host="0.0.0.0", port=8080, debug=debug_enabled)