xinyu1205 commited on
Commit
4b6c116
1 Parent(s): c606c83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  from ram import get_transform, inference_ram, inference_tag2text
3
- from ram.models import ram, tag2text_caption
4
 
5
  ram_checkpoint = "./ram_swin_large_14m.pth"
6
  tag2text_checkpoint = "./tag2text_swin_14m.pth"
@@ -44,7 +44,7 @@ if __name__ == "__main__":
44
  # get transform and load models
45
  transform = get_transform(image_size=image_size)
46
  ram_model = ram(pretrained=ram_checkpoint, image_size=image_size, vit='swin_l').eval().to(device)
47
- tag2text_model = tag2text_caption(
48
  pretrained=tag2text_checkpoint, image_size=image_size, vit='swin_b').eval().to(device)
49
 
50
  # build GUI
 
1
  import torch
2
  from ram import get_transform, inference_ram, inference_tag2text
3
+ from ram.models import ram, tag2text
4
 
5
  ram_checkpoint = "./ram_swin_large_14m.pth"
6
  tag2text_checkpoint = "./tag2text_swin_14m.pth"
 
44
  # get transform and load models
45
  transform = get_transform(image_size=image_size)
46
  ram_model = ram(pretrained=ram_checkpoint, image_size=image_size, vit='swin_l').eval().to(device)
47
+ tag2text_model = tag2text(
48
  pretrained=tag2text_checkpoint, image_size=image_size, vit='swin_b').eval().to(device)
49
 
50
  # build GUI