Elron committed on
Commit
6b89c25
·
verified ·
1 Parent(s): dcd3b86

Upload random_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. random_utils.py +28 -44
random_utils.py CHANGED
@@ -1,50 +1,34 @@
1
- import contextlib
2
  import random as python_random
3
- import string
4
- import threading
5
 
6
  __default_seed__ = 42
7
- _thread_local = threading.local()
8
 
9
-
10
- def get_seed():
11
- try:
12
- return _thread_local.seed
13
- except AttributeError:
14
- _thread_local.seed = __default_seed__
15
- return _thread_local.seed
16
-
17
-
18
- def get_random():
19
- try:
20
- return _thread_local.random
21
- except AttributeError:
22
- _thread_local.random = python_random.Random(get_seed())
23
- return _thread_local.random
24
-
25
-
26
- random = get_random()
27
 
28
 
29
- def set_seed(seed):
30
- _thread_local.seed = seed
31
- get_random().seed(seed)
32
-
33
-
34
- def get_random_string(length):
35
- letters = string.ascii_letters
36
- return "".join(get_random().choice(letters) for _ in range(length))
37
-
38
-
39
- @contextlib.contextmanager
40
- def nested_seed(sub_seed=None):
41
- old_state = get_random().getstate()
42
- old_global_seed = get_seed()
43
- sub_seed = sub_seed or get_random_string(10)
44
- new_global_seed = str(old_global_seed) + "/" + sub_seed
45
- set_seed(new_global_seed)
46
- try:
47
- yield get_random()
48
- finally:
49
- set_seed(old_global_seed)
50
- get_random().setstate(old_state)
 
 
 
 
 
1
+ import hashlib
2
  import random as python_random
 
 
3
 
4
  __default_seed__ = 42
 
5
 
6
+ from typing import Any, Hashable
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
 
9
def get_seed():
    """Return the module-wide default seed stored in ``__default_seed__``."""
    return __default_seed__
11
+
12
+
13
def new_random_generator(sub_seed: Any) -> python_random.Random:
    """Get a generator based on a seed derived from the default seed.

    The purpose is to have a random generator that provides outputs
    that are independent of previous randomizations, and that are the
    same on every run for the same ``sub_seed``.

    Args:
        sub_seed: Value combined with ``__default_seed__`` to form the seed.
            Strings are used verbatim; ints and floats are used through their
            built-in hash; every other object is reduced to a digest of its
            ``str()`` representation.

    Returns:
        A new ``random.Random`` seeded with ``"<default seed>/<sub seed>"``.
    """
    if isinstance(sub_seed, str):
        seed_part = sub_seed
    elif isinstance(sub_seed, (int, float)):
        # Numeric hashes are not salted per process, so hash() is kept here
        # to leave the seeds derived from numbers unchanged.
        seed_part = str(hash(sub_seed))
    else:
        # Lists, dicts, tuples, bytes, None, ...: plain hash(..) either fails
        # (unhashable content, e.g. a tuple holding a list) or produces a
        # value that varies between runs (str/bytes hash randomization), so
        # create a persistent hash from the textual form instead.
        # NOTE(review): objects whose str() embeds a memory address still
        # yield a run-dependent seed - confirm callers never pass such values.
        encoded = str(sub_seed).encode("utf-8")
        # limit the hash int size to 2^32 (first 8 hex digits of the digest)
        short_hexdigest = hashlib.md5(encoded).hexdigest()[:8]
        # convert to int, from base 16
        seed_part = str(int(short_hexdigest, 16))

    sub_default_seed = str(__default_seed__) + "/" + seed_part
    return python_random.Random(sub_default_seed)