merve HF staff commited on
Commit
6dc212b
1 Parent(s): c4055fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -24,8 +24,11 @@ MODEL_ID = os.getenv("MODEL_ID", MODEL_ID_FLAN_T5_XXL)
24
 
25
 
26
  processor = AutoProcessor.from_pretrained(MODEL_ID)
27
- model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID, device_map="auto", load_in_8bit=True)
28
 
 
 
 
 
29
 
30
 
31
  def generate_caption(
 
24
 
25
 
26
  processor = AutoProcessor.from_pretrained(MODEL_ID)
 
27
 
28
+ if torch.cuda.is_available():
29
+ model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID, device_map="auto", load_in_8bit=True)
30
+ else:
31
+ model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID)
32
 
33
 
34
  def generate_caption(