Alic-Li committed
Commit 36aecf2 · verified · 1 Parent(s): 971710e

Update infer/worldmodel.py

Files changed (1):
  1. infer/worldmodel.py +79 -79
infer/worldmodel.py CHANGED
@@ -1,79 +1,79 @@
 
 import numpy as np
 
 import os, sys, torch, time
 import numpy as np
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 import torch
 print(torch.__version__)
 print(torch.version.cuda)
 
 # set these before import RWKV
 # os.environ['RWKV_JIT_ON'] = '1'
 # os.environ["RWKV_CUDA_ON"] = '1' # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries
 from infer.rwkv.model import RWKV # pip install rwkv
 from infer.rwkv.utils import PIPELINE, PIPELINE_ARGS
 
 
 from world.world_encoder import WorldEncoder
 
 class Worldinfer():
-    def __init__(self, model_path, encoder_type, encoder_path, strategy='cpu bf16', args=None):
+    def __init__(self, model_path, encoder_type, encoder_path, strategy='cpu fp16', args=None):
 
         ss = strategy.split(' ')
         DEVICE = ss[0]
         if ss[1] == 'fp16':
             self.DTYPE = torch.half
         elif ss[1] == 'fp32':
             self.DTYPE = torch.float32
         elif ss[1] == 'bf16':
             self.DTYPE = torch.bfloat16
         else:
             assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16"
 
         self.model_weight = torch.load(model_path + '.pth', map_location=DEVICE)
         modality_dict = {}
         for key, value in self.model_weight.items():
             if 'emb.weight' in key:
                 _, n_embd = value.shape
             if 'modality' in key:
                 k = key.replace('modality.world_encoder.', '')
                 modality_dict[k] = value
         model = RWKV(model=self.model_weight, strategy=strategy)
         self.pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
 
         if args==None:
             self.args = PIPELINE_ARGS(temperature = 1.0, top_p = 0.0, top_k=0, # top_k = 0 then ignore
                                       alpha_frequency = 0.0,
                                       alpha_presence = 0.0,
                                       token_ban = [0], # ban the generation of some tokens
                                       token_stop = [24], # stop generation whenever you see any token here
                                       chunk_len = 256) # split input into chunks to save VRAM (shorter -> slower)
         else:
             self.args=args
         print('RWKV finish!!!')
 
         config = {
             'encoder_type': encoder_type,
             'encoder_path': encoder_path,
             'project_dim' : n_embd
         }
         self.modality = WorldEncoder(**config).to(DEVICE, torch.bfloat16)
         self.modality.load_checkpoint(modality_dict)
 
 
     def generate(self, text, modality='none', state=None):
         if isinstance(modality, str):
             y=None
         else:
             y = self.modality(modality).to(self.DTYPE)
         result, state = self.pipeline.generate(text, token_count=500, args=self.args, callback=None, state=state, sign=y)
         return result, state
 
     # def prefill(self, text, modality='none', state=None):
     #     if isinstance(modality, str):
     #         y=None
     #     else:
     #         y = self.modality(modality).to(self.DTYPE)
     #     result, state = self.pipeline.forward(text, token_count=500, args=self.args, callback=None, state=state, sign=y)
     #     return result, state
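
The only functional change in this commit is the default 'strategy' switching from 'cpu bf16' to 'cpu fp16'. For reference, a minimal usage sketch of the class after this commit; the checkpoint paths, encoder type, and input tensor shape below are hypothetical placeholders, not values from the repository:

    # Hypothetical usage sketch of Worldinfer; paths, encoder_type, and the
    # input tensor shape are placeholders, not defaults from this repo.
    import torch
    from infer.worldmodel import Worldinfer

    infer = Worldinfer(
        model_path='weights/rwkv7-world',   # loads 'weights/rwkv7-world.pth'
        encoder_type='clip',                # hypothetical encoder name
        encoder_path='weights/encoder',     # hypothetical encoder checkpoint
        strategy='cuda fp16',               # '<device> <dtype>': cuda/cpu + fp16/fp32/bf16
    )

    # Text-only call: a string `modality` means no encoder output is passed (sign=None).
    result, state = infer.generate('User: describe the scene\nAssistant:')

    # Multimodal call: any non-string modality is run through WorldEncoder first
    # and handed to the pipeline as `sign`, cast to the strategy dtype.
    image = torch.randn(1, 3, 224, 224)     # placeholder shape; depends on the chosen encoder
    result, state = infer.generate('Assistant:', modality=image, state=None)
    print(result)

The returned 'state' can be fed back into a subsequent generate() call to continue from the same RWKV recurrent state.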