SRDdev commited on
Commit
fd823f5
1 Parent(s): 53313b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -3,12 +3,13 @@ import re
3
  import gradio as gr
4
  from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
5
 
 
6
  encoder_checkpoint = "google/vit-base-patch16-224-in21k"
7
  decoder_checkpoint = "gpt2"
8
  model_checkpoint = "gagan3012/ViTGPT2I2A"
9
  feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
10
  tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
11
- model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
12
 
13
 
14
  def predict(image,max_length=64, num_beams=4):
 
3
  import gradio as gr
4
  from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
5
 
6
+
7
  encoder_checkpoint = "google/vit-base-patch16-224-in21k"
8
  decoder_checkpoint = "gpt2"
9
  model_checkpoint = "gagan3012/ViTGPT2I2A"
10
  feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
11
  tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
12
+ model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to('cpu')
13
 
14
 
15
  def predict(image,max_length=64, num_beams=4):