Ahsen Khaliq commited on
Commit
4444ae6
1 Parent(s): ade4ef4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -224,25 +224,24 @@ def generate2(
224
  return generated_list[0]
225
 
226
 
 
 
 
227
 
228
-
229
- def inference(img,model):
230
  is_gpu = False
231
 
232
- device = CUDA(0) if is_gpu else "cpu"
233
- clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
234
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
235
-
236
-
237
  prefix_length = 10
238
 
239
  model = ClipCaptionModel(prefix_length)
240
 
241
- if model == "COCO":
242
  model_path = 'coco_weights.pt'
243
  else:
244
  model_path = 'conceptual_weights.pt'
245
  model.load_state_dict(torch.load(model_path, map_location=CPU))
 
 
246
  model = model.to(device)
247
 
248
  use_beam_search = False
 
224
  return generated_list[0]
225
 
226
 
227
+ device = CUDA(0) if is_gpu else "cpu"
228
+ clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
229
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
230
 
231
+ def inference(img,model_name):
 
232
  is_gpu = False
233
 
 
 
 
 
 
234
  prefix_length = 10
235
 
236
  model = ClipCaptionModel(prefix_length)
237
 
238
+ if model_name == "COCO":
239
  model_path = 'coco_weights.pt'
240
  else:
241
  model_path = 'conceptual_weights.pt'
242
  model.load_state_dict(torch.load(model_path, map_location=CPU))
243
+ model = model.eval()
244
+ device = CUDA(0) if is_gpu else "cpu"
245
  model = model.to(device)
246
 
247
  use_beam_search = False