dnaveenr committed on
Commit
d7c3bb9
1 Parent(s): db9be24

add iSPICE files.

Files changed (4)
  1. Dockerfile +21 -0
  2. app.py +47 -0
  3. ispice.py +190 -0
  4. requirements.txt +3 -0
Dockerfile ADDED
@@ -0,0 +1,21 @@
+ FROM python:3.9
+
+ # The SPICE scorer is a Java jar, so install the Temurin 8 JDK from Adoptium.
+ RUN apt-get update
+ RUN mkdir -p /etc/apt/keyrings
+ RUN wget -O - https://packages.adoptium.net/artifactory/api/gpg/key/public | tee /etc/apt/keyrings/adoptium.asc
+ RUN echo "deb [signed-by=/etc/apt/keyrings/adoptium.asc] https://packages.adoptium.net/artifactory/deb $(awk -F= '/^VERSION_CODENAME/{print$2}' /etc/os-release) main" | tee /etc/apt/sources.list.d/adoptium.list
+ RUN apt-get update
+ RUN apt-get install -y temurin-8-jdk
+
+ WORKDIR /code
+
+ # Install Python dependencies first so this layer is cached across code edits.
+ COPY ./requirements.txt /code/requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ COPY . .
+
+ CMD ["streamlit", "run", "app.py", "--server.address", "0.0.0.0", "--server.port", "7860"]
app.py ADDED
@@ -0,0 +1,47 @@
+ import streamlit as st
+ from ispice import Spice
+
+ # Build the {image_id: [caption]} dicts that Spice.compute_score expects.
+ # Assumes one generated caption per reference caption, in the same order.
+ def preprocess_captions(generated_captions, reference_captions):
+     hypotheses = {'image' + str(i): [generated_captions[i]] for i in range(len(generated_captions))}
+     references = {'image' + str(i): [reference_captions[i]] for i in range(len(reference_captions))}
+     return hypotheses, references
+
+ # Streamlit app
+ def main():
+     st.title("iSPICE Metric Evaluation")
+
+     # Dropdown for the identity-matching mode
+     mode = st.selectbox("Mode:", ["ID", "Name"])
+
+     spice_scorer = Spice(mode=mode)
+
+     # Description
+     st.write("You can enter either a single caption or multiple captions separated by newlines.")
+     # Input text boxes
+     generated_caption = st.text_area("Generated Caption:", "")
+     reference_caption = st.text_area("Reference Caption:", "")
+
+     # Compute score button; validate the inputs before preprocessing them.
+     if st.button("Compute Score"):
+         if generated_caption.strip() == "" or reference_caption.strip() == "":
+             st.error("Please provide both generated and reference captions.")
+         else:
+             generated_captions = generated_caption.split("\n")
+             reference_captions = reference_caption.split("\n")
+             hypotheses, references = preprocess_captions(generated_captions, reference_captions)
+
+             average_spice_score, spice_scores, average_ispice_score, ispice_scores = spice_scorer.compute_score(references, hypotheses)
+             st.subheader("Scores:")
+             st.write("Average SPICE Score:", average_spice_score)
+             st.write("Average iSPICE Score:", average_ispice_score)
+             st.write("SPICE Scores:", spice_scores)
+             st.write("iSPICE Scores:", ispice_scores)
+
+ if __name__ == "__main__":
+     main()
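For reference, a minimal standalone sketch (with made-up captions) of the dict shapes that preprocess_captions builds for Spice.compute_score:

# Made-up two-line inputs, as they would arrive from the text areas.
generated = "p1 walks away.\np2 smiles."
reference = "p1 leaves the room.\np2 grins."

gen_list = generated.split("\n")   # ['p1 walks away.', 'p2 smiles.']
ref_list = reference.split("\n")   # ['p1 leaves the room.', 'p2 grins.']

# Same mapping as preprocess_captions: one single-item list per image key.
hypotheses = {'image' + str(i): [gen_list[i]] for i in range(len(gen_list))}
references = {'image' + str(i): [ref_list[i]] for i in range(len(ref_list))}

assert hypotheses == {'image0': ['p1 walks away.'], 'image1': ['p2 smiles.']}
assert references == {'image0': ['p1 leaves the room.'], 'image1': ['p2 grins.']}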
ispice.py ADDED
@@ -0,0 +1,190 @@
+ import os
+ import json
+ import subprocess
+ import tempfile
+
+ import numpy as np
+
+ # Assumes spice-1.0.jar is in the same directory as ispice.py. Change as needed.
+ SPICE_JAR = 'spice-1.0.jar'
+ TEMP_DIR = 'tmp'
+ CACHE_DIR = 'cache'
+
+ class Spice:
+     """
+     Main class to compute the SPICE and iSPICE metrics.
+     """
+
+     def __init__(self, mode="ID"):
+         self.mode = mode
+
+     def float_convert(self, obj):
+         try:
+             return float(obj)
+         except (TypeError, ValueError):
+             return np.nan
+
+     def fetch_tuples(self, tuples):
+         return [item['tuple'] for item in tuples]
+
+     def find_common(self, tuple_A, tuple_B):
+         # Count how many tuples from tuple_A also appear in tuple_B.
+         common = 0
+         for item in tuple_A:
+             if item in tuple_B:
+                 common += 1
+         return common
+
+     def get_identity_tuples(self, data):
+         # Keep tuples that mention a person ID, then split them into
+         # identity-action tuples (arity > 1) and bare identity tuples (arity 1).
+         person_ids = ["p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11"]
+         filtered_tuples = [item for item in data if any(person_id in item for person_id in person_ids)]
+         action_tuples = [tup for tup in filtered_tuples if len(tup) > 1]
+         id_tuples = list(set(tuple(tup) for tup in filtered_tuples if len(tup) == 1))
+         id_tuples = [list(tup) for tup in id_tuples]
+         return action_tuples, id_tuples
+
+     def get_named_tuples(self, data):
+         # Same as get_identity_tuples, but matching on character names.
+         names_list = ["ray", "sam", "casey", "riley", "morgan", "alex", "quinn", "cameron", "avery", "charlie", "jamie", "mike"]
+         filtered_tuples = [item for item in data if any(name in item for name in names_list)]
+         action_tuples = [tup for tup in filtered_tuples if len(tup) > 1]
+         id_tuples = list(set(tuple(tup) for tup in filtered_tuples if len(tup) == 1))
+         id_tuples = [list(tup) for tup in id_tuples]
+         return action_tuples, id_tuples
+
+     def calculate_metrics(self, pred_tuples, ref_tuples):
+         # F1 over exact tuple matches between prediction and reference.
+         common = self.find_common(pred_tuples, ref_tuples)
+         total_pred = len(pred_tuples)
+         total_ref = len(ref_tuples)
+         if total_pred == 0 or total_ref == 0:
+             return 0
+         precision = common / total_pred
+         recall = common / total_ref
+         if precision + recall == 0:
+             return 0
+         return (2 * precision * recall) / (precision + recall)
+
+     # def get_log_penalty(gt, pred):
+     #     person_ids = ["p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11"]
+     #     gt_set = set()
+     #     pred_set = set()
+     #     for word in pred.split():
+     #         if word.lower() in person_ids:
+
+     def compute_score(self, gts, res):
+         assert sorted(gts.keys()) == sorted(res.keys())
+         imgIds = sorted(gts.keys())
+
+         # Prepare temp input file for the SPICE scorer.
+         input_data = []
+         for img_id in imgIds:
+             hypo = res[img_id]
+             ref = gts[img_id]
+
+             # Sanity check.
+             assert type(hypo) is list
+             assert len(hypo) == 1
+             assert type(ref) is list
+             assert len(ref) >= 1
+
+             input_data.append({
+                 "image_id": img_id,
+                 "test": hypo[0],
+                 "refs": ref
+             })
+
+         cwd = os.path.dirname(os.path.abspath(__file__))
+         temp_dir = os.path.join(cwd, TEMP_DIR)
+         if not os.path.exists(temp_dir):
+             os.makedirs(temp_dir)
+         in_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir, mode='w+')
+         json.dump(input_data, in_file, indent=2)
+         in_file.close()
+
+         # Start job.
+         out_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
+         out_file.close()
+         cache_dir = os.path.join(cwd, CACHE_DIR)
+         if not os.path.exists(cache_dir):
+             os.makedirs(cache_dir)
+         spice_cmd = ['java', '-jar', '-Xmx8G', SPICE_JAR, in_file.name,
+                      '-cache', cache_dir,
+                      '-out', out_file.name,
+                      '-detailed',
+                      '-silent']
+         subprocess.check_call(spice_cmd, cwd=cwd)
+
+         # Read and process results.
+         with open(out_file.name) as data_file:
+             results = json.load(data_file)
+         os.remove(in_file.name)
+         os.remove(out_file.name)
+
+         imgId_to_scores = {}
+         spice_scores = []
+         ispice_scores = []
+         for item in results:
+             imgId_to_scores[item['image_id']] = item['scores']
+             spice_scores.append(self.float_convert(item['scores']['All']['f']))
+             pred_tuples = self.fetch_tuples(item['test_tuples'])
+             ref_tuples = self.fetch_tuples(item['ref_tuples'])
+             if self.mode == "ID":
+                 ia_pred_tuples, id_pred_tuples = self.get_identity_tuples(pred_tuples)
+                 ia_ref_tuples, id_ref_tuples = self.get_identity_tuples(ref_tuples)
+             elif self.mode == "Name":
+                 ia_pred_tuples, id_pred_tuples = self.get_named_tuples(pred_tuples)
+                 ia_ref_tuples, id_ref_tuples = self.get_named_tuples(ref_tuples)
+             else:
+                 raise ValueError(f"Unknown mode: {self.mode}")
+
+             # iSPICE: identity-action F1 scaled by bare-identity F1.
+             if ia_pred_tuples:
+                 i_spice_score = self.calculate_metrics(ia_pred_tuples, ia_ref_tuples)
+                 i_spice_score *= self.calculate_metrics(id_pred_tuples, id_ref_tuples)
+                 ispice_scores.append(i_spice_score)
+
+         average_spice_score = np.mean(np.array(spice_scores))
+         # Guard against an empty list (no identity-action tuples in any prediction).
+         average_ispice_score = np.mean(np.array(ispice_scores)) if ispice_scores else float('nan')
+
+         return average_spice_score, spice_scores, average_ispice_score, ispice_scores
+
+     def method(self):
+         return "iSPICE"
+
+
+ #test = Spice()
+ #test_query = {"image1": ["p1 faces him. p1 shrugs. p2 shrugs. p1 gives a faint nod."],
+ #              "image2": ["two fedex trucks parked on the side of the street."]}
+ #test_ref = {"image1": ["p1 faces him. p1 tosses down her phone. p2 considers the idea. p1 frowns."],
+ #            "image2": ["two fedex trucks parked on a side of a street with tall buildings behind them."]}
+
+ #print(test.compute_score(test_ref, test_query))
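Below, a minimal self-contained sketch of the iSPICE arithmetic with hypothetical tuples; it mirrors find_common and calculate_metrics above without invoking the SPICE jar:

def f1(pred, ref):
    # Same logic as Spice.calculate_metrics: F1 over exact tuple matches.
    if not pred or not ref:
        return 0
    common = sum(1 for t in pred if t in ref)
    precision, recall = common / len(pred), common / len(ref)
    return 0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)

# Hypothetical identity-action tuples (arity > 1) and bare identity tuples (arity 1).
ia_pred = [['p1', 'walk'], ['p2', 'sit']]
ia_ref = [['p1', 'walk']]
id_pred = [['p1'], ['p2']]
id_ref = [['p1'], ['p2']]

# iSPICE = identity-action F1 scaled by identity F1:
# F1(ia) = 2 * (1/2) * (1/1) / (1/2 + 1/1) = 2/3, F1(id) = 1.0.
print(f1(ia_pred, ia_ref) * f1(id_pred, id_ref))  # ~0.667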
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ numpy
+ requests==2.27.*
+ streamlit