merge
Browse files- app.py +9 -4
- models/epalm.py +1 -1
- models/timesformer.py +2 -0
- requirements.txt +8 -3
app.py
CHANGED
@@ -1,11 +1,15 @@
|
|
1 |
import os
|
2 |
|
3 |
os.system('cd TimeSformer;'
|
4 |
-
'
|
5 |
|
6 |
os.system('ls -l')
|
|
|
7 |
|
|
|
|
|
8 |
|
|
|
9 |
|
10 |
|
11 |
import torch
|
@@ -39,7 +43,8 @@ yaml=YAML(typ='safe')
|
|
39 |
|
40 |
|
41 |
use_cuda = torch.cuda.is_available()
|
42 |
-
device = torch.
|
|
|
43 |
|
44 |
## Load model
|
45 |
|
@@ -107,7 +112,7 @@ num_beams=3
|
|
107 |
max_length=30
|
108 |
|
109 |
|
110 |
-
|
111 |
|
112 |
|
113 |
def inference(image, audio, video, task_type, instruction):
|
@@ -129,7 +134,7 @@ def inference(image, audio, video, task_type, instruction):
|
|
129 |
|
130 |
|
131 |
|
132 |
-
with torch.autocast(device_type=
|
133 |
|
134 |
out = model(image=image, text=text_input, mode='generate', return_dict=True, max_length=max_length,
|
135 |
do_sample=do_sample, num_beams=num_beams)
|
|
|
1 |
import os
|
2 |
|
3 |
os.system('cd TimeSformer;'
|
4 |
+
'pip install .; cd ..')
|
5 |
|
6 |
os.system('ls -l')
|
7 |
+
os.system('pwd')
|
8 |
|
9 |
+
import os, sys
|
10 |
+
sys.path.append("/home/user/app/TimeSformer/")
|
11 |
|
12 |
+
import timesformer
|
13 |
|
14 |
|
15 |
import torch
|
|
|
43 |
|
44 |
|
45 |
use_cuda = torch.cuda.is_available()
|
46 |
+
device = torch.device('cuda') if use_cuda else torch.device('cpu')
|
47 |
+
device_type = 'cuda' if use_cuda else 'cpu'
|
48 |
|
49 |
## Load model
|
50 |
|
|
|
112 |
max_length=30
|
113 |
|
114 |
|
115 |
+
model.bfloat16()
|
116 |
|
117 |
|
118 |
def inference(image, audio, video, task_type, instruction):
|
|
|
134 |
|
135 |
|
136 |
|
137 |
+
with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=True):
|
138 |
|
139 |
out = model(image=image, text=text_input, mode='generate', return_dict=True, max_length=max_length,
|
140 |
do_sample=do_sample, num_beams=num_beams)
|
models/epalm.py
CHANGED
@@ -211,7 +211,7 @@ class ePALM(nn.Module):
|
|
211 |
self.no_attention_mask = False
|
212 |
|
213 |
if low_cpu:
|
214 |
-
self.model_text = OPTForCausalLM.from_pretrained(opt_model_name, config=config_opt,
|
215 |
else:
|
216 |
self.model_text = OPTForCausalLM.from_pretrained(opt_model_name, config=config_opt)
|
217 |
|
|
|
211 |
self.no_attention_mask = False
|
212 |
|
213 |
if low_cpu:
|
214 |
+
self.model_text = OPTForCausalLM.from_pretrained(opt_model_name, config=config_opt, torch_dtype=torch.float16, low_cpu_mem_usage=False)
|
215 |
else:
|
216 |
self.model_text = OPTForCausalLM.from_pretrained(opt_model_name, config=config_opt)
|
217 |
|
models/timesformer.py
CHANGED
@@ -10,6 +10,8 @@ import warnings
|
|
10 |
import torch.nn.functional as F
|
11 |
import numpy as np
|
12 |
|
|
|
|
|
13 |
from timesformer.models.vit_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
14 |
from timesformer.models.helpers import load_pretrained
|
15 |
from timesformer.models.vit_utils import DropPath, to_2tuple, trunc_normal_
|
|
|
10 |
import torch.nn.functional as F
|
11 |
import numpy as np
|
12 |
|
13 |
+
|
14 |
+
|
15 |
from timesformer.models.vit_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
16 |
from timesformer.models.helpers import load_pretrained
|
17 |
from timesformer.models.vit_utils import DropPath, to_2tuple, trunc_normal_
|
requirements.txt
CHANGED
@@ -18,11 +18,10 @@ scikit_learn
|
|
18 |
scipy
|
19 |
sentencepiece
|
20 |
setuptools
|
21 |
-
|
22 |
-
slowfast
|
23 |
submitit
|
24 |
tensorflow
|
25 |
-
timm
|
26 |
torch
|
27 |
torchaudio
|
28 |
torchvision
|
@@ -32,6 +31,12 @@ torchtyping
|
|
32 |
tqdm
|
33 |
ruamel.yaml
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
# # accelerate==0.11.0
|
36 |
# apex==0.9.10.dev0
|
37 |
# av==10.0.0
|
|
|
18 |
scipy
|
19 |
sentencepiece
|
20 |
setuptools
|
21 |
+
scikit-image
|
|
|
22 |
submitit
|
23 |
tensorflow
|
24 |
+
timm==0.6.12
|
25 |
torch
|
26 |
torchaudio
|
27 |
torchvision
|
|
|
31 |
tqdm
|
32 |
ruamel.yaml
|
33 |
|
34 |
+
|
35 |
+
git+https://github.com/facebookresearch/fvcore
|
36 |
+
simplejson
|
37 |
+
psutil
|
38 |
+
|
39 |
+
|
40 |
# # accelerate==0.11.0
|
41 |
# apex==0.9.10.dev0
|
42 |
# av==10.0.0
|