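# Spaces: Running
# (The line above is status text captured from the hosting page when this
# file was copied; it is not part of the module.)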
from __future__ import division

import ast
import json
import os
import subprocess
import sys
import tempfile
import threading

import numpy as np
# Assumes the SPICE jar (named below) is in the same directory as this file.
# Change as needed.
SPICE_JAR = 'spice-1.0.jar'
# Sub-directories (created next to this file on demand) for the scorer's
# temporary input/output files and for the directory passed to its -cache flag.
TEMP_DIR = 'tmp'
CACHE_DIR = 'cache'
class Spice:
    """
    Main Class to compute the SPICE metric.

    Besides the plain SPICE F-score, ``compute_score`` derives a second score
    (reported by ``method()`` as "iSPICE") that is computed only from the
    tuples mentioning a person marker.  Which markers count is selected by
    ``mode``: ``"ID"`` uses ``PERSON_IDS`` and ``"Name"`` uses
    ``PERSON_NAMES``.
    """

    # Markers looked for in "ID" mode.
    PERSON_IDS = ("p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11")
    # Markers looked for in "Name" mode.
    PERSON_NAMES = ("ray", "sam", "casey", "riley", "morgan", "alex", "quinn",
                    "cameron", "avery", "charlie", "jamie", "mike")

    def __init__(self, mode="ID"):
        """Store the marker mode ("ID" or "Name"); it is checked in compute_score."""
        self.mode = mode

    def float_convert(self, obj):
        """Return ``float(obj)``, or NaN when ``obj`` cannot be converted."""
        try:
            return float(obj)
        except (TypeError, ValueError, OverflowError):
            # Only conversion failures map to NaN; KeyboardInterrupt and
            # friends must not be swallowed here.
            return np.nan

    def fetch_tuples(self, tuples):
        """Return the ``'tuple'`` entry of every item in ``tuples``, in order."""
        return [item['tuple'] for item in tuples]

    def find_common(self, tuple_A, tuple_B):
        """Count the items of ``tuple_A`` that also occur in ``tuple_B``.

        Duplicates in ``tuple_A`` are counted once per occurrence.
        """
        return sum(1 for item in tuple_A if item in tuple_B)

    def _split_person_tuples(self, data, markers):
        """Split the tuples of ``data`` that contain any of ``markers``.

        A tuple is kept when ``marker in tup`` holds for at least one marker
        (presumably each tuple is a list of strings -- verify against the
        scorer's output).  Kept tuples longer than one element are returned
        unchanged as ``action_tuples``; single-element ones are de-duplicated
        (first-seen order) and returned as ``id_tuples`` (a list of lists).
        """
        filtered_tuples = [
            item for item in data if any(marker in item for marker in markers)
        ]
        action_tuples = [tup for tup in filtered_tuples if len(tup) > 1]
        # dict.fromkeys removes duplicates while keeping a deterministic order
        # (a plain set would make the order depend on string hashing).
        unique_ids = dict.fromkeys(
            tuple(tup) for tup in filtered_tuples if len(tup) == 1
        )
        id_tuples = [list(tup) for tup in unique_ids]
        return action_tuples, id_tuples

    def get_identity_tuples(self, data):
        """Return ``(action_tuples, id_tuples)`` for tuples containing a PERSON_IDS marker."""
        return self._split_person_tuples(data, self.PERSON_IDS)

    def get_named_tuples(self, data):
        """Return ``(action_tuples, id_tuples)`` for tuples containing a PERSON_NAMES marker."""
        return self._split_person_tuples(data, self.PERSON_NAMES)

    def calculate_metrics(self, pred_tuples, ref_tuples):
        """Return the F1 score of ``pred_tuples`` against ``ref_tuples``.

        Precision is ``common / len(pred_tuples)`` and recall is
        ``common / len(ref_tuples)``, where ``common`` comes from
        ``find_common``.  Returns 0 when either list is empty or when nothing
        matches.  The intermediate values are printed to stdout.
        """
        print(f"pred_tuples : {pred_tuples}")
        print(f"ref_tuples : {ref_tuples}")
        common = self.find_common(pred_tuples, ref_tuples)
        print(f"Common : {common}")
        total_pred = len(pred_tuples)
        print(f"total_pred : {total_pred}")
        total_ref = len(ref_tuples)
        print(f"total_ref : {total_ref}")
        if total_pred == 0 or total_ref == 0:
            return 0
        precision = common / total_pred
        recall = common / total_ref
        print(f"Precision : {precision}, Recall: {recall}")
        if precision + recall == 0:
            # Avoid 0/0 when there is no overlap at all.
            return 0
        return (2 * precision * recall) / (precision + recall)

    # TODO: unfinished helper left over from the original source, kept for
    # reference only.
    # def get_log_penalty(gt,pred):
    #     person_ids = ["p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11"]
    #     gt_set = set()
    #     pred_set = set()
    #     for word in pred.split():
    #         if word.lower() in person_ids:

    def compute_score(self, gts, res):
        """Run the SPICE jar on the given captions and score the results.

        Args:
            gts: dict mapping image id -> non-empty list of reference captions.
            res: dict with the same keys mapping image id -> list holding
                exactly one candidate caption.

        Returns:
            ``(average_spice_score, spice_scores, average_ispice_score,
            ispice_scores)``.  ``spice_scores`` holds one value per result
            returned by the jar; ``ispice_scores`` only has entries for
            results whose candidate yields at least one multi-element person
            tuple.  An average over an empty list is NaN.

        Raises:
            AssertionError: if the inputs do not have the shape described above.
            ValueError: if ``self.mode`` is neither "ID" nor "Name".
            subprocess.CalledProcessError: if the Java process exits non-zero.
        """
        assert(sorted(gts.keys()) == sorted(res.keys()))
        imgIds = sorted(gts.keys())

        # Prepare the input records for the SPICE scorer.
        input_data = []
        for img_id in imgIds:
            hypo = res[img_id]
            ref = gts[img_id]

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) >= 1)

            input_data.append({
                "image_id" : img_id,
                "test" : hypo[0],
                "refs" : ref
            })

        # Resolve the tuple splitter up front so that an unsupported mode fails
        # here instead of surfacing as a NameError after the Java job has run.
        tuple_splitters = {
            "ID": self.get_identity_tuples,
            "Name": self.get_named_tuples,
        }
        if self.mode not in tuple_splitters:
            raise ValueError(
                f"Unsupported mode {self.mode!r}; expected 'ID' or 'Name'"
            )
        split_tuples = tuple_splitters[self.mode]

        cwd = os.path.dirname(os.path.abspath(__file__))
        temp_dir = os.path.join(cwd, TEMP_DIR)
        os.makedirs(temp_dir, exist_ok=True)
        cache_dir = os.path.join(cwd, CACHE_DIR)
        os.makedirs(cache_dir, exist_ok=True)

        temp_paths = []
        try:
            # delete=False: the files must outlive the handles so the Java
            # process can read/write them by name.
            with tempfile.NamedTemporaryFile(delete=False, dir=temp_dir,
                                             mode='w+') as in_file:
                temp_paths.append(in_file.name)
                json.dump(input_data, in_file, indent=2)
            with tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) as out_file:
                temp_paths.append(out_file.name)

            # Start job
            spice_cmd = ['java', '-jar', '-Xmx8G', SPICE_JAR, in_file.name,
                '-cache', cache_dir,
                '-out', out_file.name,
                '-detailed',
                '-silent'
            ]
            subprocess.check_call(spice_cmd, cwd=cwd)

            # Read and process results
            with open(out_file.name) as data_file:
                results = json.load(data_file)
        finally:
            # Remove the temp files even when the scorer or the JSON parse fails.
            for path in temp_paths:
                if os.path.exists(path):
                    os.remove(path)

        spice_scores = []
        ispice_scores = []
        for item in results:
            spice_scores.append(self.float_convert(item['scores']['All']['f']))
            pred_tuples = self.fetch_tuples(item['test_tuples'])
            ref_tuples = self.fetch_tuples(item['ref_tuples'])
            ia_pred_tuples, id_pred_tuples = split_tuples(pred_tuples)
            ia_ref_tuples, id_ref_tuples = split_tuples(ref_tuples)
            # Results without any multi-element person tuple in the candidate
            # do not contribute to the iSPICE average.
            if ia_pred_tuples:
                i_spice_score = self.calculate_metrics(ia_pred_tuples, ia_ref_tuples)
                i_spice_score *= self.calculate_metrics(id_pred_tuples, id_ref_tuples)
                ispice_scores.append(i_spice_score)

        # np.mean of an empty list is NaN plus a RuntimeWarning; return NaN
        # directly in that case.
        average_spice_score = np.mean(np.array(spice_scores)) if spice_scores else np.nan
        average_ispice_score = np.mean(np.array(ispice_scores)) if ispice_scores else np.nan
        return average_spice_score, spice_scores, average_ispice_score, ispice_scores

    def method(self):
        """Return the display name of this metric."""
        return "iSPICE"
# Example usage (runs the Java SPICE scorer, so it needs `java` and SPICE_JAR):
#test = Spice()
#test_query = {"image1":["p1 faces him. p1 shrugs. p2 shrugs. p1 gives a faint nod."],
#              "image2":["two fedex trucks parked on the side of the street."]}
#test_ref = {"image1":["p1 faces him. p1 tosses down her phone. p2 considers the idea. p1 frowns."],
#            "image2":["two fedex trucks parked on a side of a street with tall buidings behind them."]}
#print(test.compute_score(test_ref, test_query))