File size: 3,466 Bytes
b3dbccc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33245ac
b3dbccc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Run ray tests.

poetry add protobuf="^3.20.1"
"""
import multiprocessing
import os
from multiprocessing import Pool
from pathlib import Path

# import joblib
import more_itertools as mit
import numpy as np
import ray
from about_time import about_time
from logzero import logger

from radio_embed import radio_embed

# Number of logical CPUs reported by the OS; used below to size pools/chunks.
num_cpus = multiprocessing.cpu_count()

# Input corpus, resolved relative to the current working directory.
filename = "fangfang-en.txt"
lines = Path(filename).read_text(encoding="utf8").splitlines()

# Whitespace-trimmed lines, with blank (or whitespace-only) lines dropped.
lst = [stripped for stripped in map(str.strip, lines) if stripped]

# with about_time() as dur: res1 = radio_embed("\n".join(lst))
# 143.72 s
# ray.init(num_cpus=num_cpus)


def test_pool(func, args_):
    """Map ``func`` over ``args_`` with a multiprocessing pool.

    A pool of ``num_cpus`` worker processes is created for the call and
    torn down when the ``with`` block exits.

    Args:
        func: Picklable callable applied to each element of ``args_``.
        args_: Iterable of inputs, one per call of ``func``.

    Returns:
        The list produced by ``Pool.map`` (results in input order).
    """
    workers = Pool(num_cpus)
    with workers:
        return workers.map(func, args_)


# Split the cleaned lines into ``num_cpus`` contiguous, nearly equal parts and
# join each part back into a single newline-separated string: one work item
# per part.  (A previous ``args = "\n".join(lst)`` assignment was overwritten
# immediately and has been dropped as a dead store.)
args = ["\n".join(elm) for elm in mit.divide(num_cpus, lst)]
# with about_time() as dur2: res2 = test_pool(radio_embed, args)
# print(dur2.duration)
# 26.5s, about 1/6 of the sequential run (143.72 s)

# res2a = np.concatenate(res2)
# np.allclose(res1, res2a, rtol=1e-05, atol=1e-07)

# %timeit ret = joblib.Parallel(n_jobs=num_cpus, backend='loky', verbose=0)(joblib.delayed(radio_embed)(arg) for arg in args)
# 28.1 s ± 1.08 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
# with about_time() as dur4: ret4 = joblib.Parallel(n_jobs=num_cpus, backend='multiprocessing', verbose=0)(joblib.delayed(radio_embed)(arg) for arg in args)
# dur4.duration 28.48s
# ret4a = np.concatenate(ret4)
# assert np.allclose(res1, ret4a, rtol=1e-05, atol=1e-07)

# Must be exported before Ray starts so that worker processes inherit it.
# NOTE(review): presumably this disables the tokenizers library's own
# parallelism inside each worker -- confirm against radio_embed's backend.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Start a local Ray runtime sized to the machine unless one is already up
# (e.g. when this module is re-run in an interactive session).
if not ray.is_initialized():
    ray.init(num_cpus=num_cpus)


@ray.remote
def ray_embed(text):
    """Ray remote task wrapping ``radio_embed``.

    Args:
        text: Input passed straight through to ``radio_embed``.

    Returns:
        Whatever ``radio_embed(text)`` returns (described in the original
        note as a d-512 embedding; not verifiable from this file).
    """
    embedding = radio_embed(text)
    return embedding


def _timed_embed(n_cpus, restart=True):
    """Embed every chunk in ``args`` with Ray and time the run.

    Args:
        n_cpus: CPU count handed to ``ray.init`` and printed as the label.
        restart: When true, shut down any running Ray runtime first so the
            trial starts on a fresh runtime limited to ``n_cpus``.  When
            false, an already-running runtime is reused as-is.

    Returns:
        A ``(embeddings, elapsed)`` tuple: the per-chunk results joined with
        ``np.concatenate`` and the human-readable duration string.
    """
    if restart:
        ray.shutdown()
    if not ray.is_initialized():
        ray.init(num_cpus=n_cpus)

    # Only task submission and result collection are timed, not Ray startup.
    with about_time() as timer:
        futures = [ray_embed.remote(arg) for arg in args]
        chunks = ray.get(futures)
    print(n_cpus, timer.duration_human)
    return np.concatenate(chunks), timer.duration_human


def main():
    """Benchmark ``ray_embed`` with several Ray CPU limits.

    Runs the same workload with ``num_cpus``, ``num_cpus // 2``, ``2`` and
    ``num_cpus - 1`` CPUs, prints each duration, logs whether each result
    matches the first run (``np.allclose``), then prints a timing summary.

    Reference (from earlier notes): the sequential
    ``radio_embed("\\n".join(lst))`` run took roughly 137-144 s.
    """
    try:
        # Baseline: reuse the runtime started at import time when present.
        baseline, elapsed = _timed_embed(num_cpus, restart=False)
        timings = [(num_cpus, elapsed)]

        # Clamp to at least one CPU: on a single-core host ``num_cpus // 2``
        # and ``num_cpus - 1`` are 0, and a runtime with zero CPUs cannot
        # schedule the default 1-CPU tasks, so ``ray.get`` would never return.
        trials = [
            (max(1, num_cpus // 2), "res6a"),
            (2, "res7a"),
            (max(1, num_cpus - 1), "res8a"),
        ]
        for n_cpus, tag in trials:
            embedded, elapsed = _timed_embed(n_cpus)
            timings.append((n_cpus, elapsed))
            same = np.allclose(baseline, embedded, rtol=1e-05, atol=1e-07)
            logger.info(" res5a allclose to %s? %s", tag, same)

        # Summary of all runs, in execution order.
        for n_cpus, elapsed in timings:
            print(n_cpus, elapsed)
    finally:
        # Release the workers of the last trial; this also lets a repeated
        # call start its baseline on a fresh ``num_cpus`` runtime.
        ray.shutdown()


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()