SRDdev commited on
Commit
53313b7
1 Parent(s): 0c25281

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -6
app.py CHANGED
@@ -3,17 +3,14 @@ import re
3
  import gradio as gr
4
  from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
5
 
6
-
7
- device = 'cpu'
8
- encoder_checkpoint = 'google/vit-base-patch16-224-in21k'
9
- decoder_checkpoint = 'distilgpt2'
10
- model_checkpoint = '"gagan3012/ViTGPT2_vizwiz"'
11
  feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
12
  tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
13
  model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
14
 
15
 
16
-
17
  def predict(image,max_length=64, num_beams=4):
18
  image = image.convert('RGB')
19
  image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
 
3
  import gradio as gr
4
  from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
5
 
6
+ encoder_checkpoint = "google/vit-base-patch16-224-in21k"
7
+ decoder_checkpoint = "gpt2"
8
+ model_checkpoint = "gagan3012/ViTGPT2I2A"
 
 
9
  feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
10
  tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
11
  model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
12
 
13
 
 
14
  def predict(image,max_length=64, num_beams=4):
15
  image = image.convert('RGB')
16
  image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)