Spaces:
Runtime error
Runtime error
romsyflux
commited on
Commit
•
e4ba58a
1
Parent(s):
8dcaef9
Added diarize_utils
Browse files- utils/diarize_utils.py +95 -0
utils/diarize_utils.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def IoU(diarized_segments: np.ndarray, asr_segments: np.ndarray) -> np.ndarray:
|
5 |
+
"""
|
6 |
+
Calculates the Intersection over Union (IoU) between diarized_segments and asr_segments.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
-----------
|
10 |
+
- diarized_segments (np.ndarray): An array representing N segments with shape (M, 2), where each row
|
11 |
+
contains the start and end times of a diarized segment.
|
12 |
+
- asr_segments (np.ndarray): An array representing M segments with shape (N, 2), where each row contains
|
13 |
+
the start and end times of an asr segment.
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
--------
|
17 |
+
- np.ndarray: A 2D array of shape (N, M) representing the IoU between each pair of diarized and.
|
18 |
+
The value at position (i, j) in the array corresponds to the IoU between the asr segment i and the diarized segment j.
|
19 |
+
Values are in the range [0, 1], where 0 indicates no intersection and 1 indicates perfect overlap.
|
20 |
+
|
21 |
+
Note:
|
22 |
+
- The IoU is calculated as the ratio of the intersection over the union of the time intervals.
|
23 |
+
- Segments with no overlap result in an IoU value of 0.
|
24 |
+
- Segments with overlap but no intersection (e.g., one segment completely contained within another) can
|
25 |
+
have an IoU greater than 0.
|
26 |
+
|
27 |
+
Example:
|
28 |
+
```python
|
29 |
+
diarized_segments = np.array([[0, 5], [3, 8], [6, 10]])
|
30 |
+
asr_segments = np.array([[2, 6], [1, 4]])
|
31 |
+
|
32 |
+
IoU_values = IoU(diarized_segments, asr_segments)
|
33 |
+
print(IoU_values)
|
34 |
+
# Output
|
35 |
+
# [[0.5 0.5 0.]
|
36 |
+
# [0.6 0.14285714 0.]]
|
37 |
+
```
|
38 |
+
"""
|
39 |
+
# We measure intersection between each of the N asr_segments [Nx2] and each M of diarize_ segments [Mx2]
|
40 |
+
# The result is a NxM matrix. intersection <= 0 mean no intersection.
|
41 |
+
starts = np.maximum(asr_segments[:, 0, np.newaxis], diarized_segments[:, 0])
|
42 |
+
ends = np.minimum(asr_segments[:, 1, np.newaxis], diarized_segments[:, 1])
|
43 |
+
intersections = np.maximum(ends - starts, 0)
|
44 |
+
|
45 |
+
# Union for segments without overlap will lead to invalid results but it does not matters
|
46 |
+
# as we opt them out eventually.
|
47 |
+
union = np.maximum(asr_segments[:, 1, np.newaxis], diarized_segments[:, 1]) - np.minimum(
|
48 |
+
asr_segments[:, 0, np.newaxis], diarized_segments[:, 0]
|
49 |
+
)
|
50 |
+
|
51 |
+
# Negative results are zeroed as they are invalid.
|
52 |
+
intersection_over_union = np.maximum(intersections / union, 0)
|
53 |
+
|
54 |
+
return intersection_over_union
|
55 |
+
|
56 |
+
|
57 |
+
def match_segments(
|
58 |
+
diarized_segments: np.ndarray,
|
59 |
+
diarized_labels: list[str],
|
60 |
+
asr_segments: np.ndarray,
|
61 |
+
threshold: float = 0.0,
|
62 |
+
no_match_label: str = "NO_SPEAKER",
|
63 |
+
) -> np.ndarray:
|
64 |
+
"""
|
65 |
+
Perform segment matching between diarized segments and ASR (Automatic Speech Recognition) segments.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
-----
|
69 |
+
- diarized_segments (np.ndarray): Array representing diarized speaker segments.
|
70 |
+
- diarized_labels (list[str]): List of labels corresponding to diarized_segments.
|
71 |
+
- asr_segments (np.ndarray): Array representing ASR speaker segments.
|
72 |
+
- threshold (float, optional): IoU (Intersection over Union) threshold for matching. Default is 0.0.
|
73 |
+
- no_match_label (str, optional): Label assigned when no matching segment is found. Default is "NO_SPEAKER".
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
--------
|
77 |
+
- np.ndarray: Array of labels corresponding to the best-matched ASR segments for each diarized segment.
|
78 |
+
|
79 |
+
Notes:
|
80 |
+
- The function calculates IoU between diarized segments and ASR segments and considers only segments with IoU above the threshold.
|
81 |
+
- If no matching segment is found, the specified `no_match_label` is assigned.
|
82 |
+
- The returned array represents the labels of the best-matched ASR segments for each diarized segment.
|
83 |
+
"""
|
84 |
+
iou_results = IoU(diarized_segments, asr_segments)
|
85 |
+
# Zero out iou below threshold.
|
86 |
+
iou_results[iou_results <= threshold] = 0.0
|
87 |
+
# We create a no match label which value will be threshold
|
88 |
+
diarized_labels = [no_match_label] + diarized_labels
|
89 |
+
# If there is nothing above threshold, no_match_label will be assigned.
|
90 |
+
iou_results = np.hstack([threshold * np.ones((iou_results.shape[0], 1)), iou_results])
|
91 |
+
# Will find argument with highest iou (if all zeroes, will assign first (no_match_label)).
|
92 |
+
best_match_idx = np.argmax(iou_results, axis=1)
|
93 |
+
assigned_labels = np.take(diarized_labels, best_match_idx)
|
94 |
+
|
95 |
+
return assigned_labels
|