File size: 1,076 Bytes
91d5a5e
 
 
 
 
 
 
 
7235a64
 
 
 
91d5a5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7235a64
91d5a5e
 
 
7235a64
91d5a5e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import warnings
from typing import Union, Callable

import torch


class SeedSchemeFactory:
    """Name-keyed registry of seed-scheme classes.

    Schemes are added with the ``register`` decorator and instantiated by
    name with ``get_instance``. The registry is a single dict stored on this
    class. Subclasses of the factory mutate that same dict unless they define
    their own ``registry``.
    """

    # Maps scheme name -> registered class (or other callable).
    registry: dict[str, Callable] = {}

    @classmethod
    def get_schemes_name(cls) -> list[str]:
        """Return the registered scheme names, in order of first registration."""
        return list(cls.registry)

    @classmethod
    def register(cls, name: str):
        """
        Register a seed scheme under ``name``. The scheme must be callable.

        Intended for use as a class decorator. Registering a name that is
        already present replaces the previous entry and emits a warning.

        Args:
            name: name of seed scheme.

        Returns:
            A decorator that records its argument in the registry and returns
            it unchanged.
        """

        def wrapper(wrapped_class):
            if name in cls.registry:
                # Use the warnings machinery rather than print() so the notice
                # goes to stderr and can be filtered or escalated to an error.
                # stacklevel=2 attributes it to the decorated definition.
                warnings.warn(
                    f"Override {name} in SeedSchemeFactory", stacklevel=2
                )
            cls.registry[name] = wrapped_class
            return wrapped_class

        return wrapper

    @classmethod
    def get_instance(cls, name: str, *args, **kwargs):
        """
        Instantiate the seed scheme registered under ``name``.

        Args:
            name: name of seed scheme.
            *args: positional arguments forwarded to the registered callable.
            **kwargs: keyword arguments forwarded to the registered callable.

        Returns:
            The result of calling the registered callable, or ``None`` if no
            scheme is registered under ``name``.
        """
        try:
            scheme_cls = cls.registry[name]
        except KeyError:
            return None
        # Called outside the try so errors raised by the constructor itself
        # are not mistaken for an unknown name.
        return scheme_cls(*args, **kwargs)


class SeedScheme:
    """Base seed scheme: a callable mapping ``input_ids`` to an integer seed.

    This base implementation ignores its argument and always yields 0.
    """

    def __call__(self, input_ids: torch.Tensor) -> int:
        del input_ids  # not consulted by the base implementation
        seed = 0
        return seed


from seed_schemes import *