m7mdal7aj committed
Commit dc81fd5 · 1 Parent(s): 36b5ae4

Update app.py

Files changed (1)
app.py +11 -5
app.py CHANGED
@@ -8,11 +8,17 @@ from PIL import Image
-from transformers import Blip2Processor, Blip2ForConditionalGeneration
+from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration
 
 
-def load_caption_model():
+def load_caption_model(blip2=False, instructblip=True):
 
-    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-    model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True, torch_dtype=torch.float16, device_map="auto")
-    #model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
+    if blip2:
+        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+        model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True, torch_dtype=torch.float16, device_map="auto")
+        #model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
+
+    if instructblip:
+        model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True, torch_dtype=torch.float16, device_map="auto")
+        processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
+
     return model, processor
 
 
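For context, a minimal sketch of how the updated load_caption_model() might be called to caption an image. This is not part of the commit: the image path, prompt text, and generation settings below are illustrative assumptions, and the sketch exercises the new InstructBLIP default path.

    import torch
    from PIL import Image
    from transformers import (
        Blip2Processor, Blip2ForConditionalGeneration,
        InstructBlipProcessor, InstructBlipForConditionalGeneration,
    )

    # Hypothetical caller of the function changed in this commit; the
    # defaults (blip2=False, instructblip=True) now select InstructBLIP.
    model, processor = load_caption_model()

    image = Image.open("example.jpg").convert("RGB")  # placeholder image path
    prompt = "Describe this image in detail."         # illustrative instruction

    # InstructBlipProcessor packs both the image and the instruction text
    # into one batch of model inputs.
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)

    # Modest generation settings; app.py may use different ones.
    output_ids = model.generate(**inputs, max_new_tokens=64)
    caption = processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    print(caption)

One caveat about the two-boolean design: if a caller passes blip2=False and instructblip=False, neither branch runs and the final return raises a NameError, since model and processor are never bound. A single argument selecting the model (or an explicit error when both flags are false) would make the failure mode clearer.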