| | import numpy as np |
| |
|
| | from keras.src.api_export import keras_export |
| |
|
| |
|
| | @keras_export( |
| | [ |
| | "keras.utils.pad_sequences", |
| | "keras.preprocessing.sequence.pad_sequences", |
| | ] |
| | ) |
def pad_sequences(
    sequences,
    maxlen=None,
    dtype="int32",
    padding="pre",
    truncating="pre",
    value=0.0,
):
    """Pads sequences to the same length.

    This function transforms a list (of length `num_samples`)
    of sequences (lists of integers)
    into a 2D NumPy array of shape `(num_samples, num_timesteps)`.
    `num_timesteps` is either the `maxlen` argument if provided,
    or the length of the longest sequence in the list.

    Sequences that are shorter than `num_timesteps`
    are padded with `value` until they are `num_timesteps` long.

    Sequences longer than `num_timesteps` are truncated
    so that they fit the desired length.

    The position where padding or truncation happens is determined by
    the arguments `padding` and `truncating`, respectively.
    Pre-padding or removing values from the beginning of the sequence is the
    default.

    If the timesteps themselves are arrays (e.g. each sequence is a list of
    vectors), their shape is inferred from the first non-empty sequence and
    appended to the output shape.

    >>> sequence = [[1], [2, 3], [4, 5, 6]]
    >>> keras.utils.pad_sequences(sequence)
    array([[0, 0, 1],
           [0, 2, 3],
           [4, 5, 6]], dtype=int32)

    >>> keras.utils.pad_sequences(sequence, value=-1)
    array([[-1, -1,  1],
           [-1,  2,  3],
           [ 4,  5,  6]], dtype=int32)

    >>> keras.utils.pad_sequences(sequence, padding='post')
    array([[1, 0, 0],
           [2, 3, 0],
           [4, 5, 6]], dtype=int32)

    >>> keras.utils.pad_sequences(sequence, maxlen=2)
    array([[0, 1],
           [2, 3],
           [5, 6]], dtype=int32)

    Args:
        sequences: List of sequences (each sequence is a list of integers).
        maxlen: Optional Int, maximum length of all sequences. If not provided,
            sequences will be padded to the length of the longest individual
            sequence (`0` when `sequences` is empty).
        dtype: (Optional, defaults to `"int32"`). Type of the output sequences.
            To pad sequences with variable length strings, you can use `object`.
        padding: String, "pre" or "post" (optional, defaults to `"pre"`):
            pad either before or after each sequence.
        truncating: String, "pre" or "post" (optional, defaults to `"pre"`):
            remove values from sequences larger than
            `maxlen`, either at the beginning or at the end of the sequences.
        value: Float or String, padding value. (Optional, defaults to `0.`)

    Returns:
        NumPy array with shape `(len(sequences), maxlen)`, followed by the
        per-timestep sample shape when the timesteps are not scalars.

    Raises:
        ValueError: If `sequences` has no length, if one of its entries is
            not a sized iterable, if `value` is a string while `dtype` is
            neither a string dtype nor `object`, if `padding` or
            `truncating` is not `"pre"` or `"post"`, or if a sequence's
            per-timestep shape differs from the inferred one.
    """
    if not hasattr(sequences, "__len__"):
        raise ValueError("`sequences` must be iterable.")
    num_samples = len(sequences)

    # Record every sequence length and take the per-timestep shape from the
    # first non-empty sequence; later sequences are checked against it.
    lengths = []
    sample_shape = ()
    shape_pending = True
    for seq in sequences:
        try:
            lengths.append(len(seq))
            if shape_pending and len(seq):
                sample_shape = np.asarray(seq).shape[1:]
                shape_pending = False
        except TypeError as e:
            raise ValueError(
                "`sequences` must be a list of iterables. "
                f"Found non-iterable: {str(seq)}"
            ) from e

    if maxlen is None:
        # `default=0` keeps an empty `sequences` from failing inside the
        # reduction; the result is then an array with zero rows.
        maxlen = max(lengths, default=0)

    # Compare by dtype semantics rather than identity so that spellings such
    # as "object", "O" or np.dtype(object) are recognised as the object dtype.
    is_dtype_str = np.issubdtype(dtype, np.str_)
    is_dtype_object = np.issubdtype(dtype, np.object_)
    if isinstance(value, str) and not (is_dtype_object or is_dtype_str):
        raise ValueError(
            f"`dtype` {dtype} is not compatible with `value`'s type: "
            f"{type(value)}\nYou should set `dtype=object` for variable length "
            "strings."
        )

    # These options do not depend on the data, so validate them once up front
    # instead of once per sequence.
    if truncating not in ("pre", "post"):
        raise ValueError(f'Truncating type "{truncating}" not understood')
    if padding not in ("pre", "post"):
        raise ValueError(f'Padding type "{padding}" not understood')

    padded = np.full(
        (num_samples, maxlen) + sample_shape, value, dtype=dtype
    )
    for idx, seq in enumerate(sequences):
        # Number of timesteps of this sequence that survive truncation.
        keep = min(len(seq), maxlen)
        if not keep:
            # Empty sequence or `maxlen == 0`: the row stays all padding.
            continue

        # Slice by `keep` rather than `maxlen`: `seq[-maxlen:]` would return
        # the whole sequence when `maxlen` is 0.
        if truncating == "pre":
            trunc = seq[-keep:]
        else:
            trunc = seq[:keep]

        # Check that `trunc` has the expected per-timestep shape.
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError(
                f"Shape of sample {trunc.shape[1:]} of sequence at "
                f"position {idx} is different from expected shape "
                f"{sample_shape}"
            )

        if padding == "post":
            padded[idx, : len(trunc)] = trunc
        else:
            padded[idx, -len(trunc) :] = trunc
    return padded
| |
|