Spaces:
Sleeping
Sleeping
| import sys | |
| import numpy as np | |
| sys.path.append('../src') | |
| from srt_util.srt import SrtScript | |
| from srt_util.srt import SrtSegment | |
def procedure(anchor, subsec, S_arr, subidx):
    """Greedily fold consecutive sub segments into the current alignment slot.

    Starting at ``subsec[subidx]``, each sub segment is merged (via ``+=``)
    into ``S_arr[-1]`` for as long as it either lies completely inside the
    anchor's time span, or ``anchor.end - sub.start`` is strictly greater
    than ``sub.end - anchor.start``. Scanning stops at the first segment that
    starts after the anchor ends, at the first segment that fails both merge
    conditions, or when ``subsec`` is exhausted.

    Args:
        anchor: segment exposing ``start`` and ``end``; the alignment anchor.
        subsec: sequence of candidate sub segments (each with ``start``/``end``).
        S_arr: output list; its last element receives the merged segments.
        subidx: index in ``subsec`` of the first candidate to examine.

    Returns:
        Index of the last sub segment that was merged, or ``subidx - 1`` when
        nothing was merged. Callers that add 1 afterwards land on the first
        segment that was not consumed.
    """
    # A plain bounded loop replaces the former "run until the index stops
    # changing" form, whose initial sentinel value of 0 made a call with
    # subidx == 0 return immediately without examining any segment.
    while subidx < len(subsec):
        sub = subsec[subidx]
        # The candidate begins after the anchor has ended: nothing to merge.
        if anchor.end < sub.start:
            break
        inside_anchor = (anchor.start <= sub.start) and (sub.end <= anchor.end)
        if not (inside_anchor or anchor.end - sub.start > sub.end - anchor.start):
            break
        S_arr[-1] += sub
        subidx += 1
    # subidx now points at the first segment that was NOT merged.
    return subidx - 1
def alignment_obsolete(pred_path, gt_path):
    """Align two SRT files segment by segment (older, greedy variant).

    Both files are parsed with ``SrtScript.parse_from_srt_file`` and walked in
    parallel. In each step the segment with the longer duration acts as the
    anchor and ``procedure`` folds following segments of the other side into
    the slot just appended. Segments without a counterpart are paired with an
    empty filler segment.

    Args:
        pred_path: path of the predicted SRT file.
        gt_path: path of the ground-truth SRT file.

    Returns:
        ``zip`` over the aligned (pred segment, ground-truth segment) pairs.
        The elements are segment objects; per the original author's note,
        collecting ``.source_text`` instead would give plain strings.
    """
    filler = SrtSegment([0,'00:00:00,000 --> 00:00:00,000','','',''])
    pred = SrtScript.parse_from_srt_file(pred_path).segments
    gt = SrtScript.parse_from_srt_file(gt_path).segments

    pred_arr = []
    gt_arr = []
    idx_p = 0  # cursor into pred
    idx_t = 0  # cursor into gt (ground truth)

    while idx_p < len(pred) or idx_t < len(gt):
        ps = pred[idx_p] if idx_p < len(pred) else None
        gs = gt[idx_t] if idx_t < len(gt) else None

        # One side is used up: pair the remaining segments with the filler.
        if not ps:
            gt_arr.append(gs)
            pred_arr.append(filler)
            idx_t += 1
            continue
        if not gs:
            pred_arr.append(ps)
            gt_arr.append(filler)
            idx_p += 1
            continue

        ps_dur = ps.end - ps.start
        gs_dur = gs.end - gs.start

        if ps_dur <= gs_dur:
            # Ground-truth segment is the anchor.
            if ps.end < gs.start:
                # pred segment ends before gt starts: no overlap, only pred advances.
                pred_arr.append(ps)
                gt_arr.append(filler)
                idx_p += 1
            elif gs.end >= ps.start:
                gt_arr.append(gs)
                pred_arr.append(ps)
                # procedure returns the last merged pred index; step past it.
                idx_p = procedure(gs, pred, pred_arr, idx_p + 1) + 1
                idx_t += 1
            else:
                # gt segment ends before pred starts: attach it to the previous gt slot.
                gt_arr[-1] += gs
                idx_t += 1
        else:
            # Predicted segment is the anchor; mirror image of the branch above.
            if gs.end < ps.start:
                gt_arr.append(gs)
                pred_arr.append(filler)
                idx_t += 1
            elif ps.end >= gs.start:
                pred_arr.append(ps)
                gt_arr.append(gs)
                idx_t = procedure(ps, gt, gt_arr, idx_t + 1) + 1
                idx_p += 1
            else:
                pred_arr[-1] += ps
                idx_p += 1

    return zip(pred_arr, gt_arr)
def alignment(pred_path, gt_path, threshold=0.5):
    """Align two SRT files by merging segments whose end times fall within a tolerance.

    Both files are parsed with ``SrtScript.parse_from_srt_file``. In each step
    the current pred and ground-truth segments are taken (or an empty filler
    once a side is exhausted). Following pred segments are merged into the
    current one while their ``end`` is at most ``gs.end + threshold``; then
    following ground-truth segments are merged while their ``end`` is at most
    the (possibly merged) pred segment's ``end + threshold``.

    Args:
        pred_path: path of the predicted SRT file.
        gt_path: path of the ground-truth SRT file.
        threshold: tolerance added to the end time in the merge comparison
            (0.5 by default; the original comment gives the unit as seconds).

    Returns:
        ``zip`` over the aligned (pred segment, ground-truth segment) pairs.
    """
    filler = SrtSegment([0, '00:00:00,000 --> 00:00:00,000', '', '', ''])
    pred = SrtScript.parse_from_srt_file(pred_path).segments
    gt = SrtScript.parse_from_srt_file(gt_path).segments
    n_pred = len(pred)
    n_gt = len(gt)

    pred_arr = []
    gt_arr = []
    p = 0  # cursor into pred
    t = 0  # cursor into gt
    while p < n_pred or t < n_gt:
        cur_pred = pred[p] if p < n_pred else filler
        cur_gt = gt[t] if t < n_gt else filler

        # Absorb following pred segments that end within tolerance of the gt segment.
        while p + 1 < n_pred and pred[p + 1].end <= cur_gt.end + threshold:
            p += 1
            cur_pred += pred[p]

        # Absorb following gt segments that end within tolerance of the merged pred segment.
        while t + 1 < n_gt and gt[t + 1].end <= cur_pred.end + threshold:
            t += 1
            cur_gt += gt[t]

        pred_arr.append(cur_pred)
        gt_arr.append(cur_gt)
        p += 1
        t += 1

    return zip(pred_arr, gt_arr)
| # Test Case | |
| #alignment('test_translation_s2.srt', 'test_translation_zh.srt') | |