Spaces:
Runtime error
Runtime error
import json
import os
import warnings

import tensorflow as tf
import werkzeug

from config import config, parseArgs, configPDF
from extract_feature import get_img_feat, build_model
from main import setSession, loadWeights, setSavers
from model import MACnet
from preprocess import Preprocesser
def _format_answer(evalRes):
    """Turn the model's raw prediction label into a natural-language sentence.

    Args:
        evalRes: the value returned by ``model.runBatch`` in ``predict``. It is
            compared against string labels here, so it is presumably a string
            class label (TODO confirm against ``MACnet.runBatch``).

    Returns:
        str: the answer sentence shown to the user.
    """
    if evalRes in ['top', 'bottom']:
        return 'The caption at the %s side of the object.' % evalRes
    if evalRes in ['True', 'False']:
        # NOTE(review): 'False' produces the same affirmative sentence as
        # 'True'; confirm this is intended.
        return 'There is at least one title object in this image.'
    # Every other label is interpolated into the object-count sentence.
    return 'This image contain %s specific object(s).' % evalRes


def predict(image, question):
    """Answer ``question`` about ``image`` with a trained MAC network.

    Every call rebuilds the whole pipeline:
    1. Parse the command-line config.
    2. Extract CNN features for ``image``.
    3. Preprocess ``question``.
    4. Construct ``MACnet`` on a fresh default graph.
    5. Restore the trained weights and run a single batch.

    Args:
        image: input image, passed unchanged to ``get_img_feat`` (the accepted
            format is defined in ``extract_feature``, not visible here).
        question: question, passed unchanged to
            ``Preprocesser.preprocessData``.

    Returns:
        str: a natural-language answer derived from the model's prediction.
    """
    parseArgs()
    configPDF()
    # The active config is appended (mode "a+") on every call, with no
    # separator between successive JSON dumps.
    with open(config.configFile(), "a+") as outFile:
        json.dump(vars(config), outFile)

    if config.gpus != "":
        config.gpusNum = len(config.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpus

    # Start every request from an empty default graph so that repeated calls
    # do not keep adding ops to the same graph.
    tf.reset_default_graph()
    tf.logging.set_verbosity(tf.logging.ERROR)

    cnn_model = build_model()
    imageData = get_img_feat(cnn_model, image)

    preprocessor = Preprocesser()
    qData, embeddings, answerDict = preprocessor.preprocessData(question)

    model = MACnet(embeddings, answerDict)
    # The initializer and savers must exist before the graph is finalized
    # inside the session below.
    init = tf.global_variables_initializer()
    savers = setSavers(model)
    saver, emaSaver = savers["saver"], savers["emaSaver"]

    sessionConfig = setSession()
    with tf.Session(config=sessionConfig) as sess:
        # Mark the graph read-only: nothing below is allowed to add new ops.
        sess.graph.finalize()
        epoch = loadWeights(sess, saver, init)
        # NOTE(review): setSavers presumably leaves "emaSaver" as None when EMA
        # is disabled; skip the EMA restore in that case instead of crashing.
        if emaSaver is not None:
            emaSaver.restore(sess, config.weightsFile(epoch))
        evalRes = model.runBatch(sess, qData, imageData, False)

    return _format_answer(evalRes)