File size: 1,094 Bytes
9fd1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import random
from typing import List, Union

import torch


def convert_byte_str_to_str(s: str, encoding: str = "utf-8") -> str:
    """
    Extracts the actual string from a stringified bytes array (common in some webdatasets).

    Example: "b'hello world'" -> "hello world"

    Only inputs that really look like a stringified bytes literal (``b'...'`` or
    ``b"..."`` with a matching closing quote) are unwrapped. Any other string is
    returned unchanged instead of having its first two and last characters chopped off.

    Args:
        s: The (possibly) stringified bytes literal.
        encoding: Codec used to decode the UTF-8 encoded inner text.

    Returns:
        The text between the quotes, re-decoded with ``encoding`` when that succeeds;
        the raw text between the quotes if re-encoding/decoding fails; or ``s`` itself
        when it is not a stringified bytes literal.
    """
    # Guard: without this check a plain caption such as "hello world" would be
    # silently truncated to "llo worl".
    if not s.startswith(("b'", 'b"')) or len(s) < 3 or s[-1] != s[1]:
        return s

    inner = s[2:-1]
    try:
        return inner.encode("utf-8").decode(encoding)
    except (UnicodeDecodeError, UnicodeEncodeError):
        # Best effort: fall back to the unwrapped text when it cannot be transcoded.
        return inner


def dropout_caption(caption: Union[str, List[str]], dropout_p: float = 0) -> Union[str, List[str]]:
    """
    Randomly replaces a caption (or every caption in a list) with an empty string.

    A single random draw decides the outcome for the whole input: the caption is
    dropped when the draw is below ``dropout_p`` and returned untouched otherwise.

    Args:
        caption: A single caption or a list of captions.
        dropout_p: Probability of dropping; ``0`` never drops, ``1`` always drops.

    Returns:
        ``caption`` itself when kept; otherwise ``""`` for a string input, or a list
        of empty strings with the same length as ``caption``.
    """
    keep_caption = random.random() >= dropout_p
    if keep_caption:
        return caption
    # Dropped: mirror the input's shape (scalar string vs. sequence of strings).
    return "" if isinstance(caption, str) else [""] * len(caption)


def dropout_embeddings_to_zero(embed: torch.Tensor, dropout_p: float = 0) -> torch.Tensor:
    """
    Randomly replaces an embedding tensor with an all-zero tensor.

    Args:
        embed: The embedding tensor.
        dropout_p: Probability of zeroing; ``0`` never zeroes, ``1`` always zeroes.

    Returns:
        ``embed`` itself when kept; otherwise a new tensor from
        ``torch.zeros_like(embed)``. The input tensor is never modified in place.
    """
    keep_embedding = random.random() >= dropout_p
    return embed if keep_embedding else torch.zeros_like(embed)


def remove_prefix(text: str, prefixes: List[str]) -> str:
    """
    Strips the first matching prefix from ``text``.

    Prefixes are tried in order and only the first one that matches is removed.
    Surrounding whitespace is trimmed from the result only when a prefix was removed.

    Args:
        text: The string to process.
        prefixes: Candidate prefixes, in priority order.

    Returns:
        ``text`` without the first matching prefix and stripped of surrounding
        whitespace, or ``text`` unchanged if no prefix matches.
    """
    matched = next((candidate for candidate in prefixes if text.startswith(candidate)), None)
    if matched is None:
        return text
    return text.removeprefix(matched).strip()