from __future__ import division
import os
import sys
import subprocess
import threading
import json
import numpy as np
import ast
import tempfile
# Assumes spice-1.0.jar is in the same directory as this file. Change as needed.
SPICE_JAR = 'spice-1.0.jar'
TEMP_DIR = 'tmp'
CACHE_DIR = 'cache'
class Spice:
"""
Main Class to compute the SPICE metric
"""
def __init__(self, mode="ID"):
self.mode = mode
    def float_convert(self, obj):
        # Return NaN for values that cannot be parsed as floats.
        try:
            return float(obj)
        except (TypeError, ValueError):
            return np.nan
def fetch_tuples(self, tuples):
result_tuples = []
for item in tuples:
result_tuples.append(item['tuple'])
return result_tuples
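    # Count how many of the tuples in tuple_A also appear in tuple_B
    # (duplicates in tuple_A are counted once per occurrence).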
def find_common(self, tuple_A, tuple_B):
common = 0
for item in tuple_A:
if item in tuple_B:
common += 1
return common
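    # Split tuples into (action_tuples, id_tuples) for person-ID mode.
    # Illustrative example (hypothetical data): [["p1"], ["p1", "shrug"], ["truck", "red"]]
    # -> action_tuples = [["p1", "shrug"]], id_tuples = [["p1"]]; the
    # ["truck", "red"] tuple is dropped because it mentions no person id.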
def get_identity_tuples(self, data):
person_ids = ["p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11"]
filtered_tuples = [item for item in data if any(person_id in item for person_id in person_ids)]
action_tuples = [tup for tup in filtered_tuples if len(tup) > 1]
id_tuples = list(set([tuple(tup) for tup in filtered_tuples if len(tup) == 1]))
id_tuples = [list(tup) for tup in id_tuples]
return action_tuples, id_tuples
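    # Same split as get_identity_tuples, but matching on character names
    # rather than person ids.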
def get_named_tuples(self, data):
names_list = ["ray", "sam", "casey", "riley", "morgan", "alex", "quinn", "cameron", "avery", "charlie", "jamie", "mike"]
filtered_tuples = [item for item in data if any(name in item for name in names_list)]
action_tuples = [tup for tup in filtered_tuples if len(tup) > 1]
id_tuples = list(set([tuple(tup) for tup in filtered_tuples if len(tup) == 1]))
id_tuples = [list(tup) for tup in id_tuples]
return action_tuples, id_tuples
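    # Tuple-overlap F1. Illustrative example (hypothetical tuples):
    # pred = [["p1"], ["p2"], ["p3"]], ref = [["p1"], ["p2"]] ->
    # common = 2, precision = 2/3, recall = 1.0, F1 = 0.8.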
def calculate_metrics(self, pred_tuples, ref_tuples):
print(f"pred_tuples : {pred_tuples}")
print(f"ref_tuples : {ref_tuples}")
common = self.find_common(pred_tuples, ref_tuples)
print(f"Common : {common}")
total_pred = len(pred_tuples)
print(f"total_pred : {total_pred}")
total_ref = len(ref_tuples)
print(f"total_ref : {total_ref}")
if total_pred == 0 or total_ref == 0:
return 0
#print(f"Common : {common}, Total Pred : {total_pred}, Total Ref: {total_ref}")
precision = common / total_pred
recall = common / total_ref
print(f"Precision : {precision}, Recall: {recall}")
if precision + recall == 0:
return 0
f1_score = (2 * precision * recall)/(precision + recall)
#print(f"precision : {precision}")
#print(f"recall : {recall}")
#print(f"f-score: {f1_score}")
return f1_score
# def get_log_penalty(gt,pred):
# person_ids = ["p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11"]
# gt_set = set()
# pred_set = set()
# for word in pred.split():
# if word.lower() in person_ids:
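    # compute_score returns (mean SPICE F-score, per-image SPICE scores,
    # mean iSPICE score, per-image iSPICE scores).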
def compute_score(self, gts, res):
assert(sorted(gts.keys()) == sorted(res.keys()))
imgIds = sorted(gts.keys())
# Prepare temp input file for the SPICE scorer
input_data = []
for id in imgIds:
hypo = res[id]
ref = gts[id]
# Sanity check.
assert(type(hypo) is list)
assert(len(hypo) == 1)
assert(type(ref) is list)
assert(len(ref) >= 1)
input_data.append({
"image_id" : id,
"test" : hypo[0],
"refs" : ref
})
cwd = os.path.dirname(os.path.abspath(__file__))
temp_dir=os.path.join(cwd, TEMP_DIR)
if not os.path.exists(temp_dir):
os.makedirs(temp_dir)
in_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir,
mode='w+')
json.dump(input_data, in_file, indent=2)
in_file.close()
# Start job
out_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
out_file.close()
cache_dir=os.path.join(cwd, CACHE_DIR)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
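        # Invoke the SPICE Java scorer on the temp input file; '-detailed' is
        # presumably what makes the per-caption test/ref tuples (used below for
        # iSPICE) appear in the output JSON.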
spice_cmd = ['java', '-jar', '-Xmx8G', SPICE_JAR, in_file.name,
'-cache', cache_dir,
'-out', out_file.name,
'-detailed',
'-silent'
]
subprocess.check_call(spice_cmd,
cwd=os.path.dirname(os.path.abspath(__file__)))
# Read and process results
with open(out_file.name) as data_file:
results = json.load(data_file)
os.remove(in_file.name)
os.remove(out_file.name)
imgId_to_scores = {}
spice_scores = []
ispice_scores = []
for item in results:
imgId_to_scores[item['image_id']] = item['scores']
spice_scores.append(self.float_convert(item['scores']['All']['f']))
pred_tuples = self.fetch_tuples(item['test_tuples'])
ref_tuples = self.fetch_tuples(item['ref_tuples'])
if(self.mode == "ID"):
ia_pred_tuples, id_pred_tuples = self.get_identity_tuples(pred_tuples)
ia_ref_tuples, id_ref_tuples = self.get_identity_tuples(ref_tuples)
elif(self.mode == "Name"):
ia_pred_tuples, id_pred_tuples = self.get_named_tuples(pred_tuples)
ia_ref_tuples, id_ref_tuples = self.get_named_tuples(ref_tuples)
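            # The per-caption iSPICE score below is the product of the F1 over
            # identity-action tuples and the F1 over identity-only tuples.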
            if len(ia_pred_tuples) != 0:
                i_spice_score = self.calculate_metrics(ia_pred_tuples, ia_ref_tuples)
                i_spice_score *= self.calculate_metrics(id_pred_tuples, id_ref_tuples)
                ispice_scores.append(i_spice_score)
        average_spice_score = np.mean(np.array(spice_scores))
        # Guard against an empty list (e.g. no identity tuples found), which would
        # otherwise emit a NumPy warning when taking the mean of an empty array.
        average_ispice_score = np.mean(np.array(ispice_scores)) if ispice_scores else float('nan')
return average_spice_score, spice_scores, average_ispice_score, ispice_scores
def method(self):
return "iSPICE"
#test = Spice()
#test_query = {"image1":["p1 faces him. p1 shrugs. p2 shrugs. p1 gives a faint nod."],
# "image2":["two fedex trucks parked on the side of the street."]}
#test_ref = {"image1":["p1 faces him. p1 tosses down her phone. p2 considers the idea. p1 frowns."],
# "image2":["two fedex trucks parked on a side of a street with tall buidings behind them."]}
#print(test.compute_score(test_ref, test_query)) |