# Captured from a Hugging Face Space page (Space status at capture time: Runtime error).
import os
import gradio as gr
from PIL import Image
import requests
from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel

# Image-captioning checkpoint on the Hugging Face Hub. The repo name suggests a
# ViT encoder + GPT-2 decoder trained on English COCO captions (not verified here).
# Preprocessing, vocabulary and weights are all loaded from the same repo so
# that they match each other.
#
# NOTE: an earlier version also built two throwaway feature extractors and a
# ViT+BERT encoder-decoder (with randomly initialised cross-attention), saved
# it to ./vit-bert and reloaded it. Every one of those objects was overwritten
# by the loads below before use, so they only cost downloads, disk and startup
# time and were removed.
repo_name = "ydshieh/vit-gpt2-coco-en"
feature_extractor = ViTFeatureExtractor.from_pretrained(repo_name)
tokenizer = AutoTokenizer.from_pretrained(repo_name)
model = VisionEncoderDecoderModel.from_pretrained(repo_name)
def get_quote(image):
    """Generate a caption for *image* with the module-level captioning model.

    Args:
        image: PIL image, as supplied by a Gradio ``Image(type="pil")`` input,
            or ``None`` when the request contains no image.

    Returns:
        A list of caption strings, one per generated sequence, with surrounding
        whitespace stripped. An empty list when ``image`` is ``None``.
    """
    # Gradio passes None when the user submits without uploading a file;
    # there is nothing to caption, so skip the model entirely.
    if image is None:
        return []
    # Presumably the model expects 3-channel input; normalise RGBA/greyscale
    # images so the feature extractor sees RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")

    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
    # Beam search (4 beams) capped at 16 tokens. Without
    # return_dict_in_generate, generate() returns the token-id tensor directly.
    output_ids = model.generate(pixel_values, max_length=16, num_beams=4)

    captions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [caption.strip() for caption in captions]
# Gradio web UI: upload an image and get a generated caption back.
title = "Sentence, listing all the items present in the image file"


def _caption_as_text(image):
    """Join get_quote's list of captions into the single string a text output displays."""
    return "\n".join(get_quote(image))


demo = gr.Interface(
    fn=_caption_as_text,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title=title,
    description="Upload an image file and get text from it",
    cache_examples=False,
)
# Enable request queueing on the app object; this replaces the deprecated
# ``enable_queue=True`` constructor argument.
demo.queue()

if __name__ == "__main__":
    # Launch only here. Chaining .launch() onto the constructor bound `demo`
    # to launch()'s return value rather than the Interface, and
    # `cache_examples` is an Interface option, not a launch() argument.
    demo.launch(debug=True)