ydin0771 commited on
Commit
5263599
β€’
1 Parent(s): 7951498

Upload demo.py

Browse files
Files changed (1) hide show
  1. demo.py +56 -0
demo.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+
4
+ import werkzeug
5
+ import tensorflow as tf
6
+
7
+ from config import config, parseArgs, configPDF
8
+ from extract_feature import get_img_feat, build_model
9
+ from main import setSession, loadWeights, setSavers
10
+ from model import MACnet
11
+ from preprocess import Preprocesser
12
+ import warnings
13
+
14
+ def predict(image, question):
15
+ parseArgs()
16
+ configPDF()
17
+ with open(config.configFile(), "a+") as outFile:
18
+ json.dump(vars(config), outFile)
19
+
20
+ if config.gpus != "":
21
+ config.gpusNum = len(config.gpus.split(","))
22
+ os.environ["CUDA_VISIBLE_DEVICES"] = config.gpus
23
+ tf.reset_default_graph()
24
+ tf.Graph().as_default()
25
+ tf.logging.set_verbosity(tf.logging.ERROR)
26
+ cnn_model = build_model()
27
+ imageData = get_img_feat(cnn_model, image)
28
+
29
+ preprocessor = Preprocesser()
30
+ qData, embeddings, answerDict = preprocessor.preprocessData(question)
31
+ model = MACnet(embeddings, answerDict)
32
+ init = tf.global_variables_initializer()
33
+
34
+ savers = setSavers(model)
35
+ saver, emaSaver = savers["saver"], savers["emaSaver"]
36
+ sessionConfig = setSession()
37
+
38
+ data = {'data': qData, 'image': imageData}
39
+
40
+ with tf.Session(config=sessionConfig) as sess:
41
+ sess.graph.finalize()
42
+
43
+ epoch = loadWeights(sess, saver, init)
44
+ emaSaver.restore(sess, config.weightsFile(epoch))
45
+
46
+ evalRes = model.runBatch(sess, data['data'], data['image'], False)
47
+ answer = None
48
+
49
+ if evalRes in ['top', 'bottom']:
50
+ answer = 'The caption at the %s side of the object.' % evalRes
51
+ elif evalRes in ['True', 'False']:
52
+ answer = 'There is at least one title object in this image.'
53
+ else:
54
+ answer = 'This image contain %s specific object(s).' % evalRes
55
+
56
+ return answer