Spaces:
Running
Running
Add doc strings
Browse files
utils.py
CHANGED
@@ -79,7 +79,50 @@ def get_gpu(gpu: DEVICE_TYPE) -> ENCODER_DEVICE_TYPE:
|
|
79 |
|
80 |
|
81 |
def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
def _slice_embeddings(s_idx: int, n_sentences: List[int]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
_result = []
|
84 |
for count in n_sentences:
|
85 |
_result.append(embeddings[s_idx:s_idx + count])
|
@@ -107,6 +150,37 @@ def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> Em
|
|
107 |
|
108 |
|
109 |
def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
if depth == 0:
|
111 |
return isinstance(lst_obj, element_type)
|
112 |
elif depth > 0:
|
|
|
79 |
|
80 |
|
81 |
def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
|
82 |
+
"""
|
83 |
+
Slice embeddings into segments based on the provided number of sentences per segment.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
- embeddings (np.ndarray): The array of embeddings to be sliced.
|
87 |
+
- num_sentences (Union[List[int], List[List[int]]]):
|
88 |
+
- If a list of integers: Specifies the number of embeddings to take in each slice.
|
89 |
+
- If a list of lists of integers: Each inner list produces its own group of consecutive slices, with slicing continuing from where the previous group ended.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
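- Union[List[np.ndarray], List[List[np.ndarray]]]: A list of numpy arrays (one slice per count) when `num_sentences` is a flat list, or a nested list of such slices (one inner list per group) when `num_sentences` is a list of lists.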
|
93 |
+
|
94 |
+
Raises:
|
95 |
+
- TypeError: If `num_sentences` is not of type List[int] or List[List[int]].
|
96 |
+
|
97 |
+
Example Usage:
|
98 |
+
|
99 |
+
```python
|
100 |
+
embeddings = np.random.rand(10, 5)
|
101 |
+
num_sentences = [3, 2, 5]
|
102 |
+
result = slice_embeddings(embeddings, num_sentences)
|
103 |
+
# `result` will be a list of numpy arrays:
|
104 |
+
# [embeddings[:3], embeddings[3:5], embeddings[5:]]
|
105 |
+
|
106 |
+
num_sentences_nested = [[2, 1], [3, 4]]
|
107 |
+
result_nested = slice_embeddings(embeddings, num_sentences_nested)
|
108 |
+
# `result_nested` will be a nested list of numpy arrays:
|
109 |
+
# [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
|
110 |
+
|
111 |
+
slice_embeddings(embeddings, "invalid") # Raises a TypeError
|
112 |
+
```
|
113 |
+
"""
|
114 |
+
|
115 |
def _slice_embeddings(s_idx: int, n_sentences: List[int]):
|
116 |
+
"""
|
117 |
+
Helper function to slice embeddings starting from index `s_idx`.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
- s_idx (int): Starting index for slicing.
|
121 |
+
- n_sentences (List[int]): List specifying number of sentences in each slice.
|
122 |
+
|
123 |
+
Returns:
|
124 |
+
- Tuple[List[np.ndarray], int]: A tuple containing a list of sliced embeddings and the next starting index.
|
125 |
+
"""
|
126 |
_result = []
|
127 |
for count in n_sentences:
|
128 |
_result.append(embeddings[s_idx:s_idx + count])
|
|
|
150 |
|
151 |
|
152 |
def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
|
153 |
+
"""
|
154 |
+
Check if the given object is a nested list of a specific type up to a specified depth.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
- lst_obj: The object to check, expected to be a list or a single element.
|
158 |
+
- element_type: The type that each element in the nested list should match.
|
159 |
+
- depth (int): The depth of nesting to check. Must be non-negative.
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
- bool: True if lst_obj is a nested list of the specified type up to the given depth, False otherwise.
|
163 |
+
|
164 |
+
Raises:
|
165 |
+
- ValueError: If depth is negative.
|
166 |
+
|
167 |
+
Example:
|
168 |
+
```python
|
169 |
+
# Test cases
|
170 |
+
is_nested_list_of_type("test", str, 0) # Returns True
|
171 |
+
is_nested_list_of_type([1, 2, 3], str, 0) # Returns False
|
172 |
+
is_nested_list_of_type(["apple", "banana"], str, 1) # Returns True
|
173 |
+
is_nested_list_of_type([[1, 2], [3, 4]], int, 2) # Returns True
|
174 |
+
is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2) # Returns False
|
175 |
+
is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3) # Returns True
|
176 |
+
```
|
177 |
+
|
178 |
+
Explanation:
|
179 |
+
- The function checks if `lst_obj` is a nested list of elements of type `element_type` up to `depth` levels deep.
|
180 |
+
- If `depth` is 0, it checks if `lst_obj` itself is of type `element_type`.
|
181 |
+
- If `depth` is greater than 0, it recursively checks each level of nesting to ensure all elements match `element_type`.
|
182 |
+
- Raises a `ValueError` if `depth` is negative, as depth must be a non-negative integer.
|
183 |
+
"""
|
184 |
if depth == 0:
|
185 |
return isinstance(lst_obj, element_type)
|
186 |
elif depth > 0:
|