from __future__ import division

import ast
import json
import os
import subprocess
import sys
import tempfile
import threading

import numpy as np

# Assumes spice.jar is in the same directory as spice.py. Change as needed.
SPICE_JAR = 'spice-1.0.jar'
TEMP_DIR = 'tmp'
CACHE_DIR = 'cache'


class Spice:
    """Compute the SPICE metric and the identity-aware iSPICE variant.

    SPICE itself is delegated to the external ``SPICE_JAR`` Java program.
    iSPICE is computed here from the tuples that program reports, restricted
    to tuples mentioning an identity token: a person id ("p1".."p11") when
    ``mode == "ID"``, or a first name when ``mode == "Name"``.
    """

    # Identity tokens recognised in each mode.
    _PERSON_IDS = ("p1", "p2", "p3", "p4", "p5", "p6",
                   "p7", "p8", "p9", "p10", "p11")
    _NAMES = ("ray", "sam", "casey", "riley", "morgan", "alex", "quinn",
              "cameron", "avery", "charlie", "jamie", "mike")
    _MODES = ("ID", "Name")

    def __init__(self, mode="ID"):
        """Store the identity mode; ``compute_score`` accepts "ID" or "Name"."""
        self.mode = mode

    def float_convert(self, obj):
        """Return ``float(obj)``, or NaN when ``obj`` cannot be converted."""
        try:
            return float(obj)
        except (TypeError, ValueError, OverflowError):
            # Only conversion failures map to NaN; anything else (for
            # example KeyboardInterrupt) must propagate.
            return np.nan

    def fetch_tuples(self, tuples):
        """Return the ``'tuple'`` entry of every item in ``tuples``."""
        return [item['tuple'] for item in tuples]

    def find_common(self, tuple_A, tuple_B):
        """Count the items of ``tuple_A`` that also occur in ``tuple_B``.

        Duplicates in ``tuple_A`` are counted once per occurrence.
        """
        return sum(1 for item in tuple_A if item in tuple_B)

    @staticmethod
    def _split_by_tokens(data, tokens):
        """Split the tuples mentioning any of ``tokens`` into two lists.

        Returns ``(action_tuples, id_tuples)``: matching tuples with more
        than one element, and the de-duplicated matching tuples with exactly
        one element (as lists, in first-seen order).
        """
        matching = [item for item in data
                    if any(token in item for token in tokens)]
        action_tuples = [tup for tup in matching if len(tup) > 1]
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        unique_ids = dict.fromkeys(
            tuple(tup) for tup in matching if len(tup) == 1)
        id_tuples = [list(tup) for tup in unique_ids]
        return action_tuples, id_tuples

    def get_identity_tuples(self, data):
        """Return ``(action_tuples, id_tuples)`` for person-id tokens."""
        return self._split_by_tokens(data, self._PERSON_IDS)

    def get_named_tuples(self, data):
        """Return ``(action_tuples, id_tuples)`` for first-name tokens."""
        return self._split_by_tokens(data, self._NAMES)

    def calculate_metrics(self, pred_tuples, ref_tuples):
        """Return the F1 score of ``pred_tuples`` against ``ref_tuples``.

        Precision is the matched count over ``len(pred_tuples)``, recall the
        matched count over ``len(ref_tuples)``. Returns 0 when either list is
        empty or nothing matches. Diagnostic values are printed to stdout.
        """
        print(f"pred_tuples : {pred_tuples}")
        print(f"ref_tuples : {ref_tuples}")
        common = self.find_common(pred_tuples, ref_tuples)
        print(f"Common : {common}")
        total_pred = len(pred_tuples)
        print(f"total_pred : {total_pred}")
        total_ref = len(ref_tuples)
        print(f"total_ref : {total_ref}")
        if total_pred == 0 or total_ref == 0:
            return 0

        precision = common / total_pred
        recall = common / total_ref
        print(f"Precision : {precision}, Recall: {recall}")
        if precision + recall == 0:
            return 0
        return (2 * precision * recall) / (precision + recall)

    # NOTE: unfinished draft kept from the original file.
    # def get_log_penalty(gt,pred):
    #     person_ids = ["p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11"]
    #     gt_set = set()
    #     pred_set = set()
    #     for word in pred.split():
    #         if word.lower() in person_ids:

    def _run_spice(self, input_data):
        """Run the SPICE jar on ``input_data`` and return its parsed JSON output.

        Temporary input/output files are created under ``TEMP_DIR`` next to
        this module and are removed even when the scorer or the JSON parsing
        fails. Raises ``subprocess.CalledProcessError`` when the Java process
        exits with a non-zero status.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        temp_dir = os.path.join(module_dir, TEMP_DIR)
        cache_dir = os.path.join(module_dir, CACHE_DIR)
        os.makedirs(temp_dir, exist_ok=True)
        os.makedirs(cache_dir, exist_ok=True)

        in_path = out_path = None
        try:
            with tempfile.NamedTemporaryFile(
                    delete=False, dir=temp_dir, mode='w+') as in_file:
                in_path = in_file.name
                json.dump(input_data, in_file, indent=2)
            # Only reserves a unique path for the scorer to write to.
            with tempfile.NamedTemporaryFile(
                    delete=False, dir=temp_dir) as out_file:
                out_path = out_file.name

            spice_cmd = ['java', '-jar', '-Xmx8G', SPICE_JAR, in_path,
                         '-cache', cache_dir,
                         '-out', out_path,
                         '-detailed',
                         '-silent']
            subprocess.check_call(spice_cmd, cwd=module_dir)

            with open(out_path) as data_file:
                return json.load(data_file)
        finally:
            for path in (in_path, out_path):
                if path is not None and os.path.exists(path):
                    os.remove(path)

    @staticmethod
    def _mean_or_nan(scores):
        """Return the mean of ``scores``, or NaN (without a warning) if empty."""
        return np.mean(np.array(scores)) if scores else np.nan

    def compute_score(self, gts, res):
        """Score candidate captions against references with SPICE and iSPICE.

        Args:
            gts: dict mapping image id to a non-empty list of reference
                captions.
            res: dict with the same keys as ``gts``, each mapping to a list
                holding exactly one candidate caption.

        Returns:
            ``(average_spice, spice_scores, average_ispice, ispice_scores)``.
            ``spice_scores`` has one entry per scored item. ``ispice_scores``
            only has entries for items whose candidate yields at least one
            identity action tuple; its average is NaN when there are none.

        Raises:
            ValueError: if ``self.mode`` is neither "ID" nor "Name".
            AssertionError: if the inputs fail the sanity checks.
            subprocess.CalledProcessError: if the SPICE jar fails.
        """
        if self.mode not in self._MODES:
            # Fail before the expensive Java run instead of hitting an
            # unbound variable while post-processing its results.
            raise ValueError(
                "Unsupported mode %r; expected one of %r"
                % (self.mode, self._MODES))
        split_tuples = (self.get_identity_tuples if self.mode == "ID"
                        else self.get_named_tuples)

        assert(sorted(gts.keys()) == sorted(res.keys()))
        img_ids = sorted(gts.keys())

        # Prepare the input records for the SPICE scorer.
        input_data = []
        for img_id in img_ids:
            hypo = res[img_id]
            ref = gts[img_id]

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) >= 1)

            input_data.append({
                "image_id": img_id,
                "test": hypo[0],
                "refs": ref,
            })

        results = self._run_spice(input_data)

        spice_scores = []
        ispice_scores = []
        for item in results:
            spice_scores.append(
                self.float_convert(item['scores']['All']['f']))

            pred_tuples = self.fetch_tuples(item['test_tuples'])
            ref_tuples = self.fetch_tuples(item['ref_tuples'])
            ia_pred_tuples, id_pred_tuples = split_tuples(pred_tuples)
            ia_ref_tuples, id_ref_tuples = split_tuples(ref_tuples)

            # Items whose candidate has no identity action tuples do not
            # contribute to iSPICE at all.
            if ia_pred_tuples:
                i_spice_score = self.calculate_metrics(
                    ia_pred_tuples, ia_ref_tuples)
                i_spice_score *= self.calculate_metrics(
                    id_pred_tuples, id_ref_tuples)
                ispice_scores.append(i_spice_score)

        average_spice_score = self._mean_or_nan(spice_scores)
        average_ispice_score = self._mean_or_nan(ispice_scores)
        return (average_spice_score, spice_scores,
                average_ispice_score, ispice_scores)

    def method(self):
        """Return the name of this metric."""
        return "iSPICE"


# Example usage:
#test = Spice()
#test_query = {"image1":["p1 faces him. p1 shrugs. p2 shrugs. p1 gives a faint nod."],
#              "image2":["two fedex trucks parked on the side of the street."]}
#test_ref = {"image1":["p1 faces him. p1 tosses down her phone. p2 considers the idea. p1 frowns."],
#            "image2":["two fedex trucks parked on a side of a street with tall buidings behind them."]}
#print(test.compute_score(test_ref, test_query))