# Source: whisper-diarization/utils/diarize_utils.py
# Last commit: e4ba58a "Added diarize_utils" (romsyflux), 4.35 kB
import numpy as np
def IoU(diarized_segments: np.ndarray, asr_segments: np.ndarray) -> np.ndarray:
    """
    Calculates the pairwise Intersection over Union (IoU) between asr_segments and diarized_segments.

    Args:
    -----------
    - diarized_segments (np.ndarray): An array (or array-like) of M segments with shape (M, 2), where
      each row contains the start and end times of a diarized segment.
    - asr_segments (np.ndarray): An array (or array-like) of N segments with shape (N, 2), where each
      row contains the start and end times of an asr segment.

    Returns:
    --------
    - np.ndarray: A 2D array of shape (N, M). The value at position (i, j) is the IoU between
      asr segment i and diarized segment j. Values are in the range [0, 1], where 0 indicates
      no intersection and 1 indicates perfect overlap.

    Note:
    - The IoU is the length of the intersection of the two time intervals divided by the length
      of their union.
    - The denominator is computed as the span from the earliest start to the latest end. For
      overlapping intervals this equals the union; for disjoint intervals it does not, but the
      intersection is 0 there, so the result is still 0.
    - A segment fully contained within another has an IoU equal to the ratio of their lengths.
    - Pairs whose span is not strictly positive (e.g. two zero-length segments at the same
      instant, or NaN bounds) yield 0 instead of NaN.

    Example:
    ```python
    diarized_segments = np.array([[0, 5], [3, 8], [6, 10]])
    asr_segments = np.array([[2, 6], [1, 4]])
    IoU_values = IoU(diarized_segments, asr_segments)
    print(IoU_values)
    # Output
    # [[0.5 0.5 0.]
    # [0.6 0.14285714 0.]]
    ```
    """
    diarized = np.asarray(diarized_segments)
    asr = np.asarray(asr_segments)

    # Column vectors (N, 1) for the asr bounds broadcast against row vectors (M,)
    # for the diarized bounds, producing (N, M) matrices.
    asr_start = asr[:, 0, np.newaxis]
    asr_end = asr[:, 1, np.newaxis]
    diarized_start = diarized[:, 0]
    diarized_end = diarized[:, 1]

    # A non-positive overlap length means the two segments do not intersect.
    overlap = np.minimum(asr_end, diarized_end) - np.maximum(asr_start, diarized_start)
    intersections = np.maximum(overlap, 0)

    # Span from the earliest start to the latest end of each pair.
    union = np.maximum(asr_end, diarized_end) - np.minimum(asr_start, diarized_start)

    # Guard the division: a zero span would give 0/0 = NaN (and a RuntimeWarning).
    # Substituting 1 in the denominator for those entries is harmless because
    # they are replaced by 0 in the final selection.
    has_extent = union > 0
    safe_union = np.where(has_extent, union, 1)
    return np.where(has_extent, intersections / safe_union, 0.0)
def match_segments(
    diarized_segments: np.ndarray,
    diarized_labels: list[str],
    asr_segments: np.ndarray,
    threshold: float = 0.0,
    no_match_label: str = "NO_SPEAKER",
) -> np.ndarray:
    """
    Assign to each ASR (Automatic Speech Recognition) segment the label of its best-matching diarized segment.

    Args:
    -----
    - diarized_segments (np.ndarray): Array of diarized speaker segments, one [start, end] row per segment.
    - diarized_labels (list[str]): Labels corresponding to diarized_segments, in the same order.
      Any sequence (list, tuple, array) is accepted.
    - asr_segments (np.ndarray): Array of ASR segments, one [start, end] row per segment.
    - threshold (float, optional): IoU (Intersection over Union) threshold for matching. Default is 0.0.
    - no_match_label (str, optional): Label assigned when no matching segment is found. Default is "NO_SPEAKER".

    Returns:
    --------
    - np.ndarray: Array with one label per ASR segment: the label of the diarized segment with the
      highest IoU, or `no_match_label` if no diarized segment has an IoU strictly above the threshold.

    Notes:
    - Only IoU values strictly greater than `threshold` are considered a match.
    - NaN IoU values are treated as no overlap.
    - When several diarized segments tie for the highest IoU, the first one wins.
    """
    iou_results = IoU(diarized_segments, asr_segments)

    # Keep only scores strictly above the threshold; everything else (including NaN,
    # which compares False) becomes 0. This builds a new array instead of mutating
    # the IoU result in place.
    scores = np.where(iou_results > threshold, iou_results, 0.0)

    # Prepend a "no match" candidate whose score is the threshold itself. Any
    # surviving score beats it; if none survive, argmax falls back to this
    # column (argmax returns the first index on ties).
    no_match_column = np.full((scores.shape[0], 1), threshold, dtype=float)
    candidates = np.hstack([no_match_column, scores])
    best_match_idx = np.argmax(candidates, axis=1)

    # Index 0 is the no-match label, matching the prepended column. Unpacking
    # (rather than list `+`) keeps this correct for tuples and ndarrays of labels.
    candidate_labels = [no_match_label, *diarized_labels]
    return np.take(candidate_labels, best_match_idx)