abhitopia
bug fix
d355235
import json
import logging
from qa_generator_pipeline import QAGeneratorPipeline
# Module-scoped logger; handler and level configuration is left to the host.
logger = logging.getLogger(__name__)

# Media type for JSON request and response bodies.
JSON_CONTENT_TYPE = "application/json"
def model_fn(model_dir, *, use_cuda=True):
    """Load the QA-generation pipeline stored in ``model_dir``.

    Args:
        model_dir: Directory holding the model artifacts; passed unchanged to
            ``QAGeneratorPipeline``.
        use_cuda: Forwarded to ``QAGeneratorPipeline``. Defaults to ``True``
            (the previously hard-coded value); pass ``False`` for CPU-only runs.

    Returns:
        The constructed ``QAGeneratorPipeline`` instance.
    """
    # Go through the module logger like the rest of the file, not the root logger.
    logger.info('[### model_fn ###] Loading model from %s', model_dir)
    return QAGeneratorPipeline(model_dir=model_dir, use_cuda=use_cuda)
def predict_fn(input_data, model):
    """Run ``model`` on the deserialized request payload.

    Args:
        input_data: Request payload (presumably the value produced by
            ``input_fn``; confirm against the serving framework).
        model: Callable (e.g. the object returned by ``model_fn``); invoked as
            ``model(input_data)``.

    Returns:
        Whatever ``model(input_data)`` returns, unchanged.
    """
    logger.info('[### predict_fn ###] Entering predict_fn() method')
    logger.info("input text: %s", input_data)
    prediction = model(input_data)
    # Log the model output; this line previously logged input_data by mistake.
    logger.info("prediction: %s", prediction)
    return prediction
def input_fn(serialized_input_data, content_type=JSON_CONTENT_TYPE):
    """Deserialize a JSON request body.

    Args:
        serialized_input_data: Raw request body; anything ``json.loads``
            accepts (str, bytes or bytearray).
        content_type: Request Content-Type. Only the media type is compared,
            case-insensitively, so ``application/json; charset=utf-8`` is
            accepted as JSON.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: if ``content_type`` is not JSON. A malformed body raises
            ``json.JSONDecodeError`` (itself a ``ValueError``) from ``json.loads``.
    """
    logger.info('[### input_fn ###] Entering input_fn() method')
    logger.info('[### input_fn ###] request_content_type: %s', content_type)
    logger.info('[### input_fn ###] request_body: %s', type(serialized_input_data))
    # Strip media-type parameters (e.g. "; charset=utf-8") and normalize case.
    media_type = (content_type or '').split(';', 1)[0].strip().lower()
    if media_type != JSON_CONTENT_TYPE:
        # Previously this fell through and returned None, which only failed
        # later inside the model with an unrelated error.
        raise ValueError('Unsupported Content Type: {}'.format(content_type))
    return json.loads(serialized_input_data)
def output_fn(prediction_output, accept=JSON_CONTENT_TYPE):
    """Serialize the prediction as JSON.

    Args:
        prediction_output: JSON-serializable model output.
        accept: Request Accept header. JSON is returned when any listed media
            type (parameters and case ignored) is ``application/json``,
            ``application/*`` or ``*/*``. Quality values (``q=``) are not
            interpreted.

    Returns:
        Tuple ``(body, content_type)``: the JSON string and the JSON media type.

    Raises:
        ValueError: if ``accept`` does not allow a JSON response.
        TypeError: from ``json.dumps`` if the output is not JSON-serializable.
    """
    logger.info('[### output_fn ###] Entering output_fn() method')
    logger.info('[### output_fn ###] prediction: %s', prediction_output)
    # Accept may be a comma-separated list with parameters,
    # e.g. "text/html, application/json;q=0.9".
    accepted = {
        part.split(';', 1)[0].strip().lower()
        for part in (accept or '').split(',')
    }
    if accepted & {JSON_CONTENT_TYPE, 'application/*', '*/*'}:
        # Report the concrete media type, never a wildcard echoed from the header.
        return json.dumps(prediction_output), JSON_CONTENT_TYPE
    # ValueError subclasses Exception, so existing `except Exception` still works.
    raise ValueError('Unsupported Content Type: {}'.format(accept))