Sem-nCG / utils.py
nbansal's picture
Fix Title
27a1559
raw
history blame
11.2 kB
import numbers
import string
from typing import List, Union

import nltk
import torch
from numpy.typing import NDArray

from .type_aliases import DEVICE_TYPE, ENCODER_DEVICE_TYPE, NumSentencesType, EmbeddingSlicesType
def get_gpu(gpu: DEVICE_TYPE) -> ENCODER_DEVICE_TYPE:
    """
    Determine the correct GPU device based on the provided input. In the following, output 0 means CUDA device 0.

    Args:
        gpu (Union[bool, str, int, List[Union[bool, str, int]]]): Input specifying the GPU device(s):
            - bool: If True, returns 0 if CUDA is available, otherwise returns "cpu".
            - str: Can be "cpu", "gpu", or "cuda" (case-insensitive). Returns 0 if CUDA is available
              and the input is not "cpu", otherwise returns "cpu".
            - int: Must be a non-negative GPU index. Returns "cpu" if CUDA is not available;
              otherwise returns the index after checking it is < torch.cuda.device_count().
            - List[Union[bool, str, int]]: Each element is resolved as described above. Duplicate
              integer indices are dropped (first occurrence kept); "cpu" entries are kept as-is.

    Returns:
        Union[str, int, List[Union[str, int]]]: Depending on the input:
            - "cpu" if CUDA is unavailable or "cpu" was requested.
            - An int GPU index if CUDA is available.
            - For a list input: the list of resolved devices. If exactly one device remains
              after de-duplication it is returned unwrapped (not as a one-element list); an
              empty input list returns an empty list.

    Raises:
        ValueError: If the input gpu type is not bool, str, int, or a list of those.
        ValueError: If a string input is not one of ["cpu", "gpu", "cuda"].
        ValueError: If an integer input is negative, or (when CUDA is available) is not smaller
            than the number of visible GPUs.

    Notes:
        - CUDA availability is checked with torch.cuda.is_available() and the number of GPUs
          with torch.cuda.device_count().
        - String inputs are matched case-insensitively.
    """
    # Query CUDA once per call; the same values are used for every element of a list input.
    gpu_available = torch.cuda.is_available()
    gpu_count = torch.cuda.device_count()
    correct_strs = ["cpu", "gpu", "cuda"]

    def _get_single_device(gpu_item):
        # bool must be tested before int: bool is a subclass of int in Python.
        if isinstance(gpu_item, bool):
            return 0 if gpu_item and gpu_available else "cpu"
        if isinstance(gpu_item, str):
            if gpu_item.lower() not in correct_strs:
                raise ValueError(f"Wrong gpu type: {gpu_item}. Should be one of {correct_strs}")
            return 0 if (gpu_item.lower() != "cpu") and gpu_available else "cpu"
        if isinstance(gpu_item, int):
            if gpu_item < 0:
                raise ValueError(f"GPU index can't be negative. You provided: {gpu_item}")
            # Without CUDA, fall back to CPU (same as the bool/str cases) instead of failing
            # the range check against a device count of 0.
            if not gpu_available:
                return "cpu"
            if gpu_item >= gpu_count:
                raise ValueError(
                    f"There are {gpu_count} GPUs available. Provide a valid GPU index. You provided: {gpu_item}"
                )
            return gpu_item
        raise ValueError(f"Invalid gpu type: {type(gpu_item)}. Must be bool, str, or int.")

    if not isinstance(gpu, list):
        return _get_single_device(gpu)

    seen_indices = set()
    result = []
    for item in gpu:
        device = _get_single_device(item)
        if isinstance(device, int):
            # Only integer GPU indices are de-duplicated; "cpu" entries are always kept.
            if device in seen_indices:
                continue
            seen_indices.add(device)
        result.append(device)
    # A list that resolves to a single device is returned unwrapped.
    return result[0] if len(result) == 1 else result
def prep_sentences(sentences: List[str]) -> List[str]:
    """
    Clean a list of sentences.

    Each sentence is stripped of leading and trailing whitespace. Sentences that are empty, or that
    consist only of punctuation and whitespace, are dropped. Only the ASCII punctuation characters
    in ``string.punctuation`` are considered punctuation. Kept sentences retain their punctuation.

    Args:
        sentences (List[str]): A list of sentences to be processed.

    Returns:
        List[str]: The cleaned sentences, in their original order.

    Raises:
        ValueError: If no sentence survives the cleaning.

    Example:
        >>> prep_sentences(["Hello, world!", " This is a test. ", "!!!"])
        ['Hello, world!', 'This is a test.']
        >>> prep_sentences(["!!!", "..."])
        ValueError: Document can't be empty.
    """
    # Build the punctuation-deleting translation table once instead of once per sentence.
    strip_punctuation = str.maketrans("", "", string.punctuation)
    out = []
    for sent in sentences:
        sent = sent.strip()
        # Keep the sentence only if something other than punctuation/whitespace remains.
        if sent.translate(strip_punctuation).strip():
            out.append(sent)
    if not out:
        raise ValueError("Document can't be empty.")
    return out
def tokenize_and_prep_document(document: Union[str, List[str]], tokenize: bool) -> List[str]:
    """
    Split a document into sentences (optionally) and clean them with `prep_sentences`.

    Args:
        document (Union[str, List[str]]): The document to process. Either a single string (an
            entire document) or a list of strings.
        tokenize (bool):
            - True: sentence-tokenize with NLTK's `sent_tokenize`. A string is tokenized as a
              whole; a list is tokenized element by element and the resulting sentences are
              concatenated in order.
            - False: use the input directly as sentences. A list is taken element by element;
              a bare string is treated as one sentence (not split into characters).

    Returns:
        List[str]: A list of cleaned sentences.

    Raises:
        ValueError: If the resulting list of sentences is empty after processing.

    Example:
        >>> tokenize_and_prep_document("Hello, world! This is a test.", True)
        ['Hello, world!', 'This is a test.']
        >>> tokenize_and_prep_document(["Hello, world!", "This is a test."], False)
        ['Hello, world!', 'This is a test.']
        >>> tokenize_and_prep_document("!!! ...", True)
        ValueError: Document can't be empty.
    """
    if tokenize:
        if isinstance(document, str):
            sentences = nltk.tokenize.sent_tokenize(document)
        else:
            # sent_tokenize takes a single string, so tokenize each element separately.
            sentences = [sent for part in document for sent in nltk.tokenize.sent_tokenize(part)]
        return prep_sentences(sentences)
    # Iterating a bare string would yield single characters; treat it as one sentence instead.
    if isinstance(document, str):
        document = [document]
    return prep_sentences(document)
def flatten_list(nested_list: list) -> list:
    """
    Flatten arbitrarily nested lists into a single flat list.

    Only ``list`` instances (including subclasses) are descended into. Every other element,
    including tuples and strings, is kept as a single item. Order is preserved: depth-first,
    left to right.

    Parameters:
        nested_list (list): The (possibly nested) list to flatten.

    Returns:
        list: A new flat list containing every non-list element of ``nested_list``.
    """
    def _walk(items):
        # Recursively yield leaves, descending into sub-lists in place.
        for element in items:
            if isinstance(element, list):
                yield from _walk(element)
            else:
                yield element

    return list(_walk(nested_list))
def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
    """
    Check whether an object is a list nested exactly ``depth`` levels deep whose leaves are all
    instances of ``element_type``.

    Args:
        - lst_obj: The object to check (a list, or a single element when depth is 0).
        - element_type: The type (or tuple of types, as accepted by ``isinstance``) every leaf must match.
        - depth (int): Number of list levels to descend. Must be non-negative.

    Returns:
        - bool: True if every element found ``depth`` levels down is an ``element_type`` and every
          intermediate level is a ``list``; False otherwise. An empty list at any level counts as
          matching.

    Raises:
        - ValueError: If depth is negative.

    Example:
        ```python
        is_nested_list_of_type("test", str, 0)                     # True
        is_nested_list_of_type([1, 2, 3], str, 0)                  # False
        is_nested_list_of_type(["apple", "banana"], str, 1)        # True
        is_nested_list_of_type([[1, 2], [3, 4]], int, 2)           # True
        is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)       # False
        is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3)   # True
        ```
    """
    # Base case: no more list levels to peel, so test the object itself.
    if depth == 0:
        return isinstance(lst_obj, element_type)
    if depth < 0:
        raise ValueError("Depth can't be negative")
    # At least one list level remains: this level must be a list, and every child must match one level shallower.
    if not isinstance(lst_obj, list):
        return False
    return all(is_nested_list_of_type(child, element_type, depth - 1) for child in lst_obj)
def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
    """
    Slice embeddings into consecutive segments based on the number of sentences per segment.

    Args:
        - embeddings (np.ndarray): The array of embeddings to be sliced (sliced along its first axis).
        - num_sentences (Union[List[int], List[List[int]]]):
            - A flat list of counts: each count is the number of rows taken by one slice.
            - A list of lists of counts: slices are grouped per inner list. The running offset
              continues across groups, so the groups partition one contiguous range of rows.
          Counts may be Python ints or any other ``numbers.Integral`` (e.g. NumPy integer scalars).

    Returns:
        - List[np.ndarray] for a flat list of counts, or List[List[np.ndarray]] for a nested one.
          No check is made that the counts sum to ``len(embeddings)``. Any rows beyond the total
          are ignored, and slices past the end come back short or empty (standard slicing).

    Raises:
        - TypeError: If `num_sentences` is not a list of integers or a list of lists of integers.

    Example Usage:
        ```python
        embeddings = np.random.rand(10, 5)
        num_sentences = [3, 2, 5]
        result = slice_embeddings(embeddings, num_sentences)
        # [embeddings[:3], embeddings[3:5], embeddings[5:]]
        num_sentences_nested = [[2, 1], [3, 4]]
        result_nested = slice_embeddings(embeddings, num_sentences_nested)
        # [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
        slice_embeddings(embeddings, "invalid")  # Raises a TypeError
        ```
    """
    def _is_count(value) -> bool:
        # numbers.Integral covers int (and therefore bool) as well as NumPy integer scalars.
        return isinstance(value, numbers.Integral)

    def _slice_embeddings(s_idx: int, n_sentences: List[int]):
        """
        Take consecutive slices of `embeddings` starting at row `s_idx`.

        Args:
            - s_idx (int): Starting row index.
            - n_sentences (List[int]): Number of rows in each successive slice.

        Returns:
            - Tuple[List[np.ndarray], int]: The slices, and the row index just after the last one.
        """
        _result = []
        for count in n_sentences:
            _result.append(embeddings[s_idx:s_idx + count])
            s_idx += count
        return _result, s_idx

    if isinstance(num_sentences, list) and all(_is_count(item) for item in num_sentences):
        result, _ = _slice_embeddings(0, num_sentences)
        return result
    if isinstance(num_sentences, list) and all(
        isinstance(sublist, list) and all(_is_count(item) for item in sublist)
        for sublist in num_sentences
    ):
        nested_result = []
        start_idx = 0
        for nested_num_sentences in num_sentences:
            # Carry the offset forward so each group starts where the previous one ended.
            embedding_slice, start_idx = _slice_embeddings(start_idx, nested_num_sentences)
            nested_result.append(embedding_slice)
        return nested_result
    raise TypeError(f"Incorrect Type for {num_sentences=}")