minchul commited on
Commit
f2be5ee
1 Parent(s): c25d296

Upload directory

Browse files
Files changed (1) hide show
  1. models/vit_kprpe/RPE/__init__.py +48 -0
models/vit_kprpe/RPE/__init__.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .KPRPE import kprpe_shared
2
+ import torch
3
+ import warnings
4
+ import subprocess
5
+ import sys
6
+ import os
7
+
8
+ try:
9
+ from .rpe_ops.rpe_index import RPEIndexFunction
10
+ except ImportError:
11
+ try:
12
+ # Attempt to install the module from the setup.py script
13
+ dirname = os.path.dirname(os.path.abspath(__file__))
14
+ cwd = os.getcwd()
15
+ os.chdir(os.path.join(dirname, 'rpe_ops'))
16
+ subprocess.check_call([sys.executable, 'setup.py', 'install', '--user'])
17
+ GREEN_STR = "\033[92m{}\033[00m"
18
+ print(GREEN_STR.format("\n[INFO] Successfully installed `rpe_ops`. Restart Application"),)
19
+ sys.exit()
20
+ except subprocess.CalledProcessError as install_error:
21
+ RED_STR = "\033[91m{}\033[00m"
22
+ warnings.warn(RED_STR.format("\n[WARNING] Failed to install `rpe_ops`. "
23
+ "Please check the installation script."),)
24
+ except ImportError as import_error:
25
+ RED_STR = "\033[91m{}\033[00m"
26
+ warnings.warn(RED_STR.format("\n[WARNING] The module `rpe_ops` is not built. "
27
+ "For better training performance, please build `rpe_ops`."),)
28
+
29
+
30
+ def build_rpe(rpe_config, head_dim, num_heads):
31
+ if rpe_config is None:
32
+ return None
33
+ else:
34
+ name = rpe_config.name
35
+ if name == 'KPRPE_shared':
36
+ rpe_config = kprpe_shared.get_rpe_config(
37
+ ratio=rpe_config.ratio,
38
+ method=rpe_config.method,
39
+ mode=rpe_config.mode,
40
+ shared_head=rpe_config.shared_head,
41
+ skip=0,
42
+ rpe_on=rpe_config.rpe_on,
43
+ )
44
+ return kprpe_shared.build_rpe(rpe_config, head_dim=head_dim, num_heads=num_heads)
45
+
46
+ else:
47
+ raise NotImplementedError(f"Unknow RPE: {name}")
48
+