File size: 2,599 Bytes
113dbd0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import os
import sys
import logging
from flask import Flask, request, jsonify
from flask_cors import CORS
from serve import get_model_api
# define the app
app = Flask(__name__)
CORS(app)  # needed for cross-domain requests, allow everything by default

# Send application logs to stdout at INFO level. Stdout is also where Heroku
# collects logs from, so a single unconditional handler covers both local
# and Heroku runs. The handler is attached exactly once: attaching a second
# identical StreamHandler makes every log record appear twice.
app.logger.addHandler(logging.StreamHandler(sys.stdout))
app.logger.setLevel(logging.INFO)

# load the model
model_api = get_model_api()
# API route
@app.route('/api', methods=['POST'])
def api():
    """Run the model on the posted JSON payload and return its output as JSON.

    The request body is read with ``request.json`` and must provide the keys
    ``system_prompt``, ``history`` and ``USER``; their values are passed, in
    that order, to ``model_api``. All model-specific logic is defined in
    ``get_model_api()``.

    The input and the output are sent to the application logger and appended
    to ``test_topic_serve_log.csv``. Only the output write ends with a
    newline, so the input and output of one request share a line.

    Returns:
        A JSON response built from the value returned by ``model_api``.

    Raises:
        KeyError: if the payload is a JSON object that lacks one of the
            required keys.
    """
    input_data = request.json

    # The context manager closes (and therefore flushes) the log file on
    # every exit path, including when the model call raises. A bare open()
    # here would leak one file handle per request.
    with open("test_topic_serve_log.csv", 'a', encoding='utf-8') as log:
        app.logger.info("api_input: %s", input_data)
        log.write("api_input: " + str(input_data))

        input_sys_prompt_str = input_data['system_prompt']
        input_USER_str = input_data['USER']
        input_history_str = input_data['history']

        output_data = model_api(
            input_sys_prompt_str, input_history_str, input_USER_str
        )
        app.logger.info("api_output: %s", output_data)

        response = jsonify(output_data)
        log.write("api_output: " + str(output_data) + "\n")

    return response
# API2 route
@app.route('/labelapi', methods=['POST'])
def labelapi():
    """Record a user label action.

    The posted JSON payload is sent to the application logger and appended,
    one line per request, to ``test_topic_label_log.csv``. No model is
    involved in this route.

    Returns:
        A JSON response of the form
        ``{"input": <posted payload>, "output": "record_success"}``.
    """
    input_data = request.json
    app.logger.info("api_input: %s", input_data)

    # The context manager closes (and flushes) the log file after the write;
    # a bare open() here would leak one file handle per request.
    with open("test_topic_label_log.csv", 'a', encoding='utf-8') as log:
        log.write("api_input: " + str(input_data) + "\n")

    output_data = {"input": input_data, "output": "record_success"}
    # Build the response explicitly with jsonify, like the /api route does,
    # instead of relying on Flask converting a returned dict.
    return jsonify(output_data)
@app.route('/')
def index():
    """Serve a fixed plain-text body for the root URL."""
    body = "Index API"
    return body
# HTTP Errors handlers
@app.errorhandler(404)
def url_error(e):
    """Handle 404 errors.

    Returns a ``(body, status)`` tuple: a short message with the error
    ``e`` rendered inside a ``<pre>`` element, and the status code 404.
    """
    template = "\nWrong URL!\n<pre>{}</pre>"
    body = template.format(e)
    return body, 404
@app.errorhandler(500)
def server_error(e):
    """Handle 500 errors.

    Returns a ``(body, status)`` tuple: a short message with the error
    ``e`` rendered inside a ``<pre>`` element plus a pointer to the logs,
    and the status code 500.
    """
    template = (
        "\nAn internal error occurred: <pre>{}</pre>\n"
        "See logs for full stacktrace.\n"
    )
    body = template.format(e)
    return body, 500
if __name__ == '__main__':
    # This is used when running locally.
    #
    # The server binds to all interfaces (0.0.0.0), so debug mode must not be
    # on by default: Flask's debug mode enables the Werkzeug interactive
    # debugger, which lets anyone who can reach the port execute arbitrary
    # Python code. Opt in explicitly with FLASK_DEBUG=1 (or "true"/"yes").
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=4455, debug=debug_enabled)
|