florentgbelidji committed
Commit: a7a4721
Parent: baa2ff5

Changing weights and fixes

model_large_caption.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d79b3b7c41478b5fe55c35b73ca6f3525a09708289371c6c0fac641e588287e
+size 1785411505
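
The file added here is not the checkpoint itself but a Git LFS pointer: the repository tracks only the object's SHA-256 and byte size (about 1.7 GB), and the actual weights are fetched from LFS storage on clone or download. As a sanity check, here is a minimal sketch (assuming the resolved checkpoint sits next to the pointer's path) that verifies the downloaded file against the pointer's oid and size:

import hashlib
import os

# Values copied from the LFS pointer above.
EXPECTED_OID = "0d79b3b7c41478b5fe55c35b73ca6f3525a09708289371c6c0fac641e588287e"
EXPECTED_SIZE = 1785411505

def verify_lfs_object(path: str) -> None:
    """Check that a resolved LFS file matches its pointer's size and sha256."""
    actual_size = os.path.getsize(path)
    assert actual_size == EXPECTED_SIZE, f"size mismatch: {actual_size}"

    sha = hashlib.sha256()
    with open(path, "rb") as f:
        # Stream in chunks so the 1.7 GB file never sits in memory at once.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)
    assert sha.hexdigest() == EXPECTED_OID, "sha256 mismatch"

verify_lfs_object("model_large_caption.pth")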
models/blip_decoder.py CHANGED
@@ -8,8 +8,8 @@
 import warnings
 warnings.filterwarnings("ignore")
 
-from vit import VisionTransformer, interpolate_pos_embed
-from med import BertConfig, BertModel, BertLMHeadModel
+from models.vit import VisionTransformer, interpolate_pos_embed
+from models.med import BertConfig, BertModel, BertLMHeadModel
 from transformers import BertTokenizer
 
 import torch
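
The switch from bare to package-qualified imports is what lets the handler at the repository root load this module. A minimal sketch of the assumed layout (vit.py and med.py are inferred from the imports; only blip_decoder.py appears in this commit):

# Assumed layout, inferred from the imports in this commit:
#
#   model_large_caption.pth
#   pipeline.py              <- entry point, run from the repo root
#   models/
#       blip_decoder.py
#       vit.py
#       med.py
#
# Python resolves imports against the entry point's directory (the repo
# root), so siblings inside models/ must be addressed via the package:
from models.vit import VisionTransformer, interpolate_pos_embed  # resolves
# from vit import VisionTransformer   # ModuleNotFoundError from the root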
pipeline.py CHANGED
@@ -10,12 +10,11 @@ from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-print(device)
 
 class PreTrainedPipeline():
-    def __init__(self):
+    def __init__(self, path=""):
         # load the optimized model
-        self.model_path = 'model_base_capfilt_large.pth'
+        self.model_path = 'model_large_caption.pth'
         self.model = blip_decoder(
             pretrained=self.model_path,
             image_size=384,
@@ -34,7 +33,7 @@ class PreTrainedPipeline():
 
 
 
-    def __call__(self, data: Any) -> Dict[str]:
+    def __call__(self, data: Any) -> Dict[str, Any]:
         """
         Args:
             data (:obj:):
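
Taken together, the pipeline.py changes align the handler with the convention used by Hugging Face custom inference handlers: the runtime constructs the class with the repository path, so __init__ must accept one (hence path=""), and __call__ must return a JSON-serializable mapping (hence Dict[str, Any] in place of the invalid Dict[str], which takes two type parameters). A hedged local smoke test follows; the payload key "image" and the shape of the returned dict are assumptions for illustration, not shown in this diff:

# Hypothetical smoke test for the handler above.
from PIL import Image

from pipeline import PreTrainedPipeline

pipe = PreTrainedPipeline(path=".")          # the runtime passes the repo path
payload = {"image": Image.open("demo.jpg")}  # assumed input format
result = pipe(payload)                       # returns a Dict[str, Any]
print(result)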