farrosalferro24 committed on
Commit
955a80d
·
verified ·
1 Parent(s): 92bd08d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -40,7 +40,7 @@ print_keyword=True
40
  print_topk_patches = True
41
 
42
  torch_dtype = torch.float16
43
- attn_implementation = 'flash_attention_2'
44
  device_map = 'cuda'
45
 
46
  conv_template = conv_templates['llama_3']
@@ -56,7 +56,8 @@ config = GeckoConfig.from_pretrained(model,
56
  print_keyword=print_keyword)
57
  processor = GeckoProcessor.from_pretrained(model, config=config, use_keyword=True, cropping_method=cropping_method, crop_size=crop_size)
58
  model = GeckoForConditionalGeneration.from_pretrained(
59
- model, config=config)
 
60
  model.load_text_encoder(processor)
61
 
62
  @spaces.GPU
 
40
  print_topk_patches = True
41
 
42
  torch_dtype = torch.float16
43
+ attn_implementation = 'sdpa'
44
  device_map = 'cuda'
45
 
46
  conv_template = conv_templates['llama_3']
 
56
  print_keyword=print_keyword)
57
  processor = GeckoProcessor.from_pretrained(model, config=config, use_keyword=True, cropping_method=cropping_method, crop_size=crop_size)
58
  model = GeckoForConditionalGeneration.from_pretrained(
59
+ model, config=config, torch_dtype=torch_dtype,
60
+ attn_implementation=attn_implementation, device_map=device_map)
61
  model.load_text_encoder(processor)
62
 
63
  @spaces.GPU