Macrodove commited on
Commit
1cbd56e
1 Parent(s): 3f9cb68

Draft version, one bug need to be fixed

Browse files

Former-commit-id: ebeb78f42b55688a10d5b48df701b27a074da087

Files changed (1) hide show
  1. evaluation/alignment.py +89 -0
evaluation/alignment.py CHANGED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import numpy as np
3
+ sys.path.append('../src')
4
+ from srt_util.srt import SrtScript
5
+
6
+
7
+ def procedure(anchor,subsec,S_arr,subidx):
8
+
9
+ temp = subsec[subidx - 1]
10
+ print('------------------------------')
11
+ print(anchor)
12
+ print(temp)
13
+
14
+ cache_idx = 0
15
+ while subidx != cache_idx: # loop until alignment stablized
16
+ cache_idx = subidx # reinitialize cache
17
+ # Inside interval
18
+ if subidx >= len(subsec): continue
19
+ sub = subsec[subidx]
20
+ if (anchor.end < sub.start): continue
21
+ if (anchor.start < sub.start) & (sub.end < anchor.end):
22
+ S_arr[len(S_arr) - 1] += sub.source_text
23
+ subidx += 1
24
+ elif anchor.end - sub.start > sub.end - anchor.start:
25
+ S_arr[len(S_arr) - 1] += sub.source_text
26
+ subidx += 1
27
+
28
+
29
+ print(sub)
30
+ print(S_arr[len(S_arr) - 1])
31
+ print('------------------------------')
32
+
33
+ subidx -= 1 # reset subidx to last segment
34
+
35
+
36
+ def alignment(pred_path,gt_path,threshold = 0.3):
37
+ pred = SrtScript.parse_from_srt_file(pred_path).segments
38
+ gt = SrtScript.parse_from_srt_file(gt_path).segments
39
+ pred_arr = []
40
+ gt_arr = []
41
+ duration = 0
42
+ #count = 0
43
+ #for ps,gs in zip(pred,gt):
44
+ # duration += ps.end + gs.end - ps.start - gs.start
45
+ # count += len(ps.source_text) + len(gs.source_text)
46
+ #density = count / duration #word density
47
+ idx_p, idx_t = -1, -1
48
+ while idx_p < len(pred) or idx_t < len(gt):
49
+ idx_p += 1
50
+ idx_t += 1
51
+ try:
52
+ ps = pred[idx_p]
53
+ gs = gt[idx_t]
54
+ except IndexError:
55
+ if idx_t >= len(gt):
56
+ pred_arr.append(ps.source_text)
57
+ continue
58
+ if idx_p >= len(pred):
59
+ gs = gt[idx_t]
60
+ gt_arr.append(gs.source_text)
61
+ continue
62
+ #print('init' + str(idx_t) + str(idx_p))
63
+ #duration
64
+ ps_dur = ps.end - ps.start
65
+ gs_dur = gs.end - gs.start
66
+ #forward/backward
67
+ if ps_dur <= gs_dur:
68
+ gt_arr.append(gs.source_text)
69
+ if gs.end < ps.start:
70
+ idx_p -= 1 # reset idx if no match
71
+ continue
72
+ pred_arr.append(ps.source_text)
73
+ idx_p += 1
74
+ procedure(gs,pred,pred_arr,idx_p)
75
+ else:
76
+ pred_arr.append(ps.source_text)
77
+ if ps.end < gs.start:
78
+ idx_t -= 1 # reset idx if no match
79
+ continue
80
+ gt_arr.append(gs.source_text)
81
+ idx_t += 1
82
+ procedure(ps,gt,gt_arr,idx_t)
83
+ #print(pred_arr)
84
+ #print(gt_arr)
85
+
86
+ return zip(pred_arr,gt_arr)
87
+
88
+
89
+ alignment('../results/OVB/OVB_en.srt','../results/OVM/OVM_en.srt')