Spaces:
Sleeping
Sleeping
import statistics | |
import string | |
import sys | |
from dataclasses import dataclass, field | |
from typing import List, Union, Tuple | |
import nltk | |
import torch | |
from numpy.typing import NDArray | |
from .type_aliases import DEVICE_TYPE, ENCODER_DEVICE_TYPE, NumSentencesType, EmbeddingSlicesType | |
def get_gpu(gpu: DEVICE_TYPE) -> ENCODER_DEVICE_TYPE:
    """
    Determine the correct GPU device based on the provided input. In the following, output 0 means CUDA device 0.

    Args:
        gpu (Union[bool, str, int, List[Union[bool, str, int]]]): Input specifying the GPU device(s):
            - bool: If True, returns 0 if CUDA is available, otherwise returns "cpu".
            - str: Can be "cpu", "gpu", or "cuda" (case-insensitive). Returns 0 if CUDA is available
              and the input is not "cpu", otherwise returns "cpu".
            - int: Must be a valid GPU index, i.e. ``0 <= gpu < torch.cuda.device_count()``; the index
              is returned unchanged. When ``torch.cuda.device_count()`` is 0 no index is valid, so
              every int raises ValueError.
            - List[Union[bool, str, int]]: Each element is resolved as above. Repeated GPU indices are
              dropped (first occurrence kept, order preserved), while every "cpu" result is kept. If
              exactly one device remains it is returned on its own rather than in a list.

    Returns:
        Union[str, int, List[Union[str, int]]]: Depending on the input type:
            - str: "cpu" if no GPU is available or the input is "cpu".
            - int: The CUDA device index.
            - List[Union[str, int]]: The resolved devices when a list input yields zero or more than
              one device.

    Raises:
        ValueError: If an input (or list element) is not a bool, str, or int.
        ValueError: If a string input is not one of ["cpu", "gpu", "cuda"].
        ValueError: If an integer input is negative or not below the number of available GPUs.

    Notes:
        - CUDA availability is checked once per call with torch.cuda.is_available(), and GPUs are
          counted with torch.cuda.device_count().
        - Case insensitivity is maintained for string inputs ("cpu", "gpu", "cuda").
        - bool is tested before int because bool is a subclass of int in Python.
    """
    gpu_available = torch.cuda.is_available()
    gpu_count = torch.cuda.device_count()
    correct_strs = ["cpu", "gpu", "cuda"]

    def _get_single_device(gpu_item):
        # bool must be checked before int: isinstance(True, int) is True.
        if isinstance(gpu_item, bool):
            return 0 if gpu_item and gpu_available else "cpu"
        elif isinstance(gpu_item, str):
            if gpu_item.lower() not in correct_strs:
                raise ValueError(f"Wrong gpu type: {gpu_item}. Should be one of {correct_strs}")
            return 0 if (gpu_item.lower() != "cpu") and gpu_available else "cpu"
        elif isinstance(gpu_item, int):
            # Check both bounds: an upper-bound-only check let negative indices such as -1 through.
            if not 0 <= gpu_item < gpu_count:
                raise ValueError(
                    f"There are {gpu_count} GPUs available. Provide a valid GPU index. You provided: {gpu_item}"
                )
            # Defensive fallback: after the range check this only matters if the availability
            # and device-count queries ever disagree.
            return gpu_item if gpu_available else "cpu"
        else:
            raise ValueError(f"Invalid gpu type: {type(gpu_item)}. Must be bool, str, or int.")

    if not isinstance(gpu, list):
        return _get_single_device(gpu)

    seen_indices = set()
    result = []
    for item in gpu:
        device = _get_single_device(item)
        if isinstance(device, int):
            # Keep only the first occurrence of each CUDA index; "cpu" entries are never deduplicated.
            if device in seen_indices:
                continue
            seen_indices.add(device)
        result.append(device)
    # A single resolved device is unwrapped for convenience.
    return result[0] if len(result) == 1 else result
def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
    """
    Cut ``embeddings`` into consecutive segments whose sizes are given by ``num_sentences``.

    Segments are taken back to back from the start of ``embeddings``; in the nested form the
    offset keeps running across groups, so the whole structure describes one contiguous pass.

    Args:
        - embeddings (np.ndarray): The array of embeddings to be sliced.
        - num_sentences (Union[List[int], List[List[int]]]):
            - A list of integers: the size of each segment, in order.
            - A list of lists of integers: one inner list of segment sizes per group.

    Returns:
        - List[np.ndarray] for flat ``num_sentences``, or List[List[np.ndarray]] (one inner list
          per group) for nested ``num_sentences``.

    Raises:
        - TypeError: If `num_sentences` is not of type List[int] or List[List[int]].

    Example Usage:
        ```python
        embeddings = np.random.rand(10, 5)
        slice_embeddings(embeddings, [3, 2, 5])
        # -> [embeddings[:3], embeddings[3:5], embeddings[5:]]
        slice_embeddings(embeddings, [[2, 1], [3, 4]])
        # -> [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
        slice_embeddings(embeddings, "invalid")  # Raises a TypeError
        ```
    """
    def _take(start: int, counts: List[int]):
        """Cut chunks of the given sizes starting at ``start``; return (chunks, index after last chunk)."""
        pieces = []
        cursor = start
        for size in counts:
            end = cursor + size
            pieces.append(embeddings[cursor:end])
            cursor = end
        return pieces, cursor

    is_list = isinstance(num_sentences, list)

    # Flat form: every entry is a segment size.
    if is_list and all(isinstance(n, int) for n in num_sentences):
        flat, _ = _take(0, num_sentences)
        return flat

    # Nested form: every entry is itself a list of segment sizes.
    is_nested = is_list and all(
        isinstance(group, list) and all(isinstance(n, int) for n in group)
        for group in num_sentences
    )
    if not is_nested:
        raise TypeError(f"Incorrect Type for {num_sentences=}")

    grouped = []
    offset = 0
    for group in num_sentences:
        chunk, offset = _take(offset, group)
        grouped.append(chunk)
    return grouped
def is_nested_list_of_type(lst_obj, element_type, depth: int) -> Tuple[bool, str]:
    """
    Check if the given object is a nested list of a specific type up to a specified depth.

    Args:
        - lst_obj: The object to check, expected to be a list or a single element.
        - element_type: The type that each element in the nested list should match. Anything
          ``isinstance`` accepts works, including a tuple of types such as ``(int, float)``.
        - depth (int): The depth of nesting to check. Must be non-negative.

    Returns:
        - Tuple[bool, str]: A tuple containing:
            - A boolean indicating if lst_obj is a nested list of the specified type up to the given depth.
            - A string containing an error message if the check fails, or an empty string if the check passes.

    Raises:
        - ValueError: If depth is negative.

    Example:
        ```python
        is_nested_list_of_type("test", str, 0)  # Returns (True, "")
        is_nested_list_of_type([1, 2, 3], str, 0)  # Returns (False, "Element is of type list, expected type str.")
        is_nested_list_of_type(["apple", "banana"], str, 1)  # Returns (True, "")
        is_nested_list_of_type([[1, 2], [3, 4]], int, 2)  # Returns (True, "")
        is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)
        # Returns (False, "Element at index 1 has incorrect type.\\nGiven Element at index 1: ['a', 'b']"
        #                 "\\nElement is of type str, expected type int.")
        is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3)  # Returns (True, "")
        ```

    Explanation:
        - If `depth` is 0, `lst_obj` itself must be an instance of `element_type`.
        - If `depth` is greater than 0, `lst_obj` must be a list and each of its items is checked
          recursively with `depth - 1`; checking stops at the first failing item.
        - Only the outermost level prefixes the failing item's index and value to the error
          message; deeper levels pass the innermost message up unchanged.
        - Note that bool values pass an `int` check because bool is a subclass of int.
    """
    orig_depth = depth

    def _type_name(e_type) -> str:
        # isinstance() accepts (nested) tuples of types and union objects, which have no usable
        # __name__; render them readably instead of crashing while building an error message.
        if isinstance(e_type, tuple):
            return " or ".join(_type_name(t) for t in e_type)
        return getattr(e_type, "__name__", str(e_type))

    def _is_nested_list_of_type(lst_o, e_type, d) -> Tuple[bool, str]:
        if d < 0:
            raise ValueError("Depth can't be negative")
        if d == 0:
            if isinstance(lst_o, e_type):
                return True, ""
            return False, f"Element is of type {type(lst_o).__name__}, expected type {_type_name(e_type)}."
        if not isinstance(lst_o, list):
            return False, f"Object is not a list but {type(lst_o)}."
        for i, item in enumerate(lst_o):
            is_valid, err = _is_nested_list_of_type(item, e_type, d - 1)
            if not is_valid:
                if d == orig_depth:
                    # Outermost level: say which top-level item failed and show it.
                    err = (f"Element at index {i} has incorrect type.\nGiven Element at index {i}: {item}"
                           f"\n{err}")
                return False, err
        return True, ""

    return _is_nested_list_of_type(lst_obj, element_type, depth)
def flatten_list(nested_list: list) -> list:
    """
    Flatten a nested list of any depth into a single flat list.

    Only ``list`` instances are expanded; every other element (tuples, strings, arrays, ...) is
    copied to the output as-is. Elements keep their left-to-right, depth-first order.

    The traversal uses an explicit stack instead of recursion, so very deep nesting does not run
    into Python's recursion limit.

    Parameters:
        nested_list (list): The nested list to flatten.

    Returns:
        list: A new flat list containing all the non-list elements of ``nested_list``.

    Raises:
        ValueError: If a list contains itself, directly or through nested lists (a cycle cannot
            be flattened).
    """
    flat_list = []
    # Each stack entry is (id of the list being walked, iterator over it). `on_path` holds the ids
    # of the lists currently being walked, so a cycle is reported instead of looping forever.
    stack = [(id(nested_list), iter(nested_list))]
    on_path = {id(nested_list)}
    while stack:
        _, items = stack[-1]
        for item in items:
            if isinstance(item, list):
                if id(item) in on_path:
                    raise ValueError("Cannot flatten a list that contains itself.")
                on_path.add(id(item))
                stack.append((id(item), iter(item)))
                break  # descend into the sublist; this iterator resumes once the sublist is done
            flat_list.append(item)
        else:
            # Current list exhausted: step back up to its parent.
            finished_id, _ = stack.pop()
            on_path.discard(finished_id)
    return flat_list
def compute_f1(p: float, r: float, eps=sys.float_info.epsilon) -> float:
    """
    Compute the F1 score, i.e. the harmonic mean of precision and recall.

    ``eps`` is added to the denominator so that ``p = r = 0`` yields ``0.0`` rather than raising
    ZeroDivisionError; when ``p + r`` is much larger than ``eps`` its effect is negligible.

    :param p: Precision value.
    :param r: Recall value.
    :param eps: Small constant added to the denominator (defaults to ``sys.float_info.epsilon``).
    :return: ``2 * p * r / (p + r + eps)``.
    """
    numerator = 2 * p * r
    denominator = p + r + eps
    return numerator / denominator
def sent_tokenize(text: str) -> List[str]:
    """
    Split ``text`` into sentences using NLTK's sentence tokenizer.

    Surrounding whitespace is stripped before tokenizing. Blank input (empty or whitespace-only)
    short-circuits to ``[""]`` without calling NLTK, so the result always has at least one
    element. Each sentence returned by NLTK is stripped as well.

    NOTE(review): ``nltk.tokenize.sent_tokenize`` may require NLTK's Punkt tokenizer data to be
    installed locally; confirm the deployment provides it.

    Args:
        text (str): The text to split into sentences.

    Returns:
        List[str]: The stripped sentences, or ``[""]`` for blank input.
    """
    stripped = text.strip()
    if not stripped:
        return [""]
    sentences = nltk.tokenize.sent_tokenize(stripped)
    return list(map(str.strip, sentences))
@dataclass
class Scores:
    """
    Data class representing evaluation scores including precision, recall, and computed F1 score.

    The ``@dataclass`` decorator is required: without it ``field(init=False)`` is never processed,
    ``Scores(precision, recall)`` raises TypeError, and ``__post_init__`` never runs.

    Attributes:
        - precision (float): The precision score for the evaluation.
        - recall (List[float]): List of recall scores, one per reference.
        - f1 (float): F1 score computed in ``__post_init__`` from ``precision`` and the mean of
          ``recall``; it is not a constructor argument.

    Raises:
        statistics.StatisticsError: At construction time if ``recall`` is empty, because
            ``statistics.fmean`` needs at least one value.
    """

    precision: float
    recall: List[float]
    # Derived value, so excluded from the generated __init__.
    f1: float = field(init=False)

    def __post_init__(self):
        # Combine the single precision value with the mean recall across references.
        self.f1 = compute_f1(self.precision, statistics.fmean(self.recall))