romsyflux committed on
Commit
e4ba58a
1 Parent(s): 8dcaef9

Added diarize_utils

Browse files
Files changed (1) hide show
  1. utils/diarize_utils.py +95 -0
utils/diarize_utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
def IoU(diarized_segments: np.ndarray, asr_segments: np.ndarray) -> np.ndarray:
    """
    Calculates the Intersection over Union (IoU) between every ASR segment and every diarized segment.

    Args:
    -----------
    - diarized_segments (np.ndarray): An array of M segments with shape (M, 2), where each row
      contains the start and end times of a diarized segment.
    - asr_segments (np.ndarray): An array of N segments with shape (N, 2), where each row contains
      the start and end times of an asr segment.

    Returns:
    --------
    - np.ndarray: A 2D array of shape (N, M) with the IoU between each pair of asr and diarized segments.
      The value at position (i, j) is the IoU between the asr segment i and the diarized segment j.
      Values are in the range [0, 1], where 0 indicates no intersection and 1 indicates perfect overlap.

    Note:
    - The IoU is the length of the intersection of the two time intervals divided by the length of their union.
    - Segments that do not overlap (or that only touch at an endpoint) have an IoU of 0.
    - A segment completely contained within another has an IoU equal to the ratio of their lengths.
    - Pairs whose union has no extent (e.g. two zero-length segments at the same instant) have an IoU of 0
      instead of NaN.
    - An empty input is treated as zero segments and a single 1-D [start, end] pair as one segment.

    Example:
    ```python
    diarized_segments = np.array([[0, 5], [3, 8], [6, 10]])
    asr_segments = np.array([[2, 6], [1, 4]])

    IoU_values = IoU(diarized_segments, asr_segments)
    print(IoU_values)
    # Output
    # [[0.5 0.5 0.]
    # [0.6 0.14285714 0.]]
    ```
    """
    asr = np.asarray(asr_segments)
    diarized = np.asarray(diarized_segments)
    # An empty array carries no usable shape, so treat it as "zero segments";
    # otherwise promote a lone [start, end] pair to a single-row matrix.
    asr = asr.reshape(0, 2) if asr.size == 0 else np.atleast_2d(asr)
    diarized = diarized.reshape(0, 2) if diarized.size == 0 else np.atleast_2d(diarized)

    # Column vectors for the N asr segments broadcast against the M diarized
    # segments, so every intermediate below is an NxM matrix.
    asr_start = asr[:, 0, np.newaxis]
    asr_end = asr[:, 1, np.newaxis]
    diarized_start = diarized[:, 0]
    diarized_end = diarized[:, 1]

    # A non-positive difference means the two intervals do not intersect.
    overlap_start = np.maximum(asr_start, diarized_start)
    overlap_end = np.minimum(asr_end, diarized_end)
    intersections = np.maximum(overlap_end - overlap_start, 0)

    # This is the span from the earliest start to the latest end. It equals the
    # true union only for overlapping pairs; for disjoint pairs the
    # intersection is 0, so the resulting IoU is 0 regardless.
    union = np.maximum(asr_end, diarized_end) - np.minimum(asr_start, diarized_start)

    # Only divide where the union has positive extent; elsewhere (zero or
    # negative span) the IoU is defined as 0, which avoids 0/0 -> NaN.
    has_extent = union > 0
    safe_union = np.where(has_extent, union, 1)
    intersection_over_union = np.where(has_extent, intersections / safe_union, 0.0)

    return intersection_over_union
55
+
56
+
57
def match_segments(
    diarized_segments: np.ndarray,
    diarized_labels: list[str],
    asr_segments: np.ndarray,
    threshold: float = 0.0,
    no_match_label: str = "NO_SPEAKER",
) -> np.ndarray:
    """
    Assign to each ASR (Automatic Speech Recognition) segment the label of the diarized segment it overlaps most.

    Args:
    -----
    - diarized_segments (np.ndarray): Array of diarized speaker segments, one (start, end) row per segment.
    - diarized_labels (list[str]): Labels corresponding to diarized_segments, in the same order.
      Any sequence of strings (list, tuple, array) is accepted.
    - asr_segments (np.ndarray): Array of ASR segments, one (start, end) row per segment.
    - threshold (float, optional): IoU (Intersection over Union) threshold for matching. Default is 0.0.
    - no_match_label (str, optional): Label assigned when no matching segment is found. Default is "NO_SPEAKER".

    Returns:
    --------
    - np.ndarray: Array with one label per ASR segment: the label of the diarized segment with the highest IoU,
      or `no_match_label` when no diarized segment qualifies.

    Notes:
    - The function calculates IoU between ASR segments and diarized segments and considers only pairs with
      IoU strictly above the threshold.
    - A pair with zero (or undefined/NaN) IoU never matches, even if the threshold is negative.
    - If no matching segment is found, the specified `no_match_label` is assigned.
    - Ties are resolved in favour of the earliest diarized segment.
    """
    iou_results = IoU(diarized_segments, asr_segments)
    # Keep only scores strictly above the threshold. NaN compares False, so
    # undefined scores are dropped too. np.where builds a new array instead of
    # mutating the one returned by IoU.
    iou_results = np.where(iou_results > threshold, iou_results, 0.0)

    # Column 0 is a sentinel "no match" candidate; prepend its label so the
    # label indices line up with the columns. Unpacking accepts any sequence of
    # labels, not just a list.
    candidate_labels = [no_match_label, *diarized_labels]
    # The sentinel scores `threshold`, clamped at 0 so that a negative
    # threshold cannot let a zero-overlap pair beat it.
    sentinel_score = max(threshold, 0.0)
    sentinel_column = np.full((iou_results.shape[0], 1), sentinel_score)
    scores = np.hstack([sentinel_column, iou_results])

    # argmax returns the first maximum, so a row with nothing above the
    # sentinel resolves to column 0 (no_match_label).
    best_match_idx = np.argmax(scores, axis=1)
    assigned_labels = np.take(candidate_labels, best_match_idx)

    return assigned_labels