SRDdev commited on
Commit
56b9e35
1 Parent(s): fd823f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -3,13 +3,13 @@ import re
3
  import gradio as gr
4
  from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
5
 
6
-
7
  encoder_checkpoint = "google/vit-base-patch16-224-in21k"
8
  decoder_checkpoint = "gpt2"
9
  model_checkpoint = "gagan3012/ViTGPT2I2A"
10
  feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
11
  tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
12
- model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to('cpu')
13
 
14
 
15
  def predict(image,max_length=64, num_beams=4):
 
3
  import gradio as gr
4
  from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
5
 
6
+ device='cpu'
7
  encoder_checkpoint = "google/vit-base-patch16-224-in21k"
8
  decoder_checkpoint = "gpt2"
9
  model_checkpoint = "gagan3012/ViTGPT2I2A"
10
  feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
11
  tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
12
+ model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
13
 
14
 
15
  def predict(image,max_length=64, num_beams=4):