m7mdal7aj commited on
Commit
63fc765
1 Parent(s): 85f811b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -1
app.py CHANGED
@@ -4,19 +4,29 @@ import bitsandbytes
4
  import accelerate
5
  import scipy
6
  from PIL import Image
7
-
8
  from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration
9
 
10
 
11
  def load_caption_model(blip2=False, instructblip=True):
12
 
 
 
 
 
13
  if blip2:
14
  processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True,torch_dtype=torch.float16)
15
  model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True,torch_dtype=torch.float16)
 
 
 
16
  #model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
17
 
18
  if instructblip:
19
  model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True,torch_dtype=torch.float16)
 
 
 
20
  processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True,torch_dtype=torch.float16)
21
 
22
  return model, processor
 
4
  import accelerate
5
  import scipy
6
  from PIL import Image
7
+ import torch.nn as nn
8
  from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration
9
 
10
 
11
  def load_caption_model(blip2=False, instructblip=True):
12
 
13
+
14
+ model = YourModel()
15
+
16
+
17
  if blip2:
18
  processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True,torch_dtype=torch.float16)
19
  model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True,torch_dtype=torch.float16)
20
+ if torch.cuda.device_count() > 1:
21
+ model = nn.DataParallel(model)
22
+ model.to('cuda')
23
  #model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
24
 
25
  if instructblip:
26
  model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True,torch_dtype=torch.float16)
27
+ if torch.cuda.device_count() > 1:
28
+ model = nn.DataParallel(model)
29
+ model.to('cuda')
30
  processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True,torch_dtype=torch.float16)
31
 
32
  return model, processor