import json

import torch
import numpy as np
import faiss
import clip
import gradio as gr
from transformers import AutoTokenizer, CLIPFeatureExtractor, AutoModel, AutoModelForCausalLM
from transformers.models.auto.configuration_auto import AutoConfig

from src.vision_encoder_decoder import SmallCap, SmallCapConfig
from src.gpt2 import ThisGPT2Config, ThisGPT2LMHeadModel  # used by the commented-out GPT-2 variant below
from src.opt import ThisOPTConfig, ThisOPTForCausalLM
from src.utils import prep_strings, postprocess_preds
from src.retrieve_caps import *
device = "cuda" if torch.cuda.is_available() else "cpu"

# load feature extractor
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

# load and configure tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
tokenizer.pad_token = '!'
tokenizer.eos_token = '.'
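# Note: '!' and '.' stand in as pad and EOS tokens here, which (as far as we
# can tell) mirrors the SmallCap training setup; with '.' as EOS, beam search
# stops at the end of the first sentence.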
# load model: register the custom config/decoder classes with the Auto*
# factories so from_pretrained can resolve the "this_opt" and "smallcap"
# model types from the hub config
# (uncomment the block below instead to use the GPT-2 decoder variant)
# AutoConfig.register("this_gpt2", ThisGPT2Config)
# AutoModel.register(ThisGPT2Config, ThisGPT2LMHeadModel)
# AutoModelForCausalLM.register(ThisGPT2Config, ThisGPT2LMHeadModel)
# AutoConfig.register("smallcap", SmallCapConfig)
# AutoModel.register(SmallCapConfig, SmallCap)
# model = AutoModel.from_pretrained("Yova/SmallCap7M")
AutoConfig.register("this_opt", ThisOPTConfig)
AutoModel.register(ThisOPTConfig, ThisOPTForCausalLM)
AutoModelForCausalLM.register(ThisOPTConfig, ThisOPTForCausalLM)
AutoConfig.register("smallcap", SmallCapConfig)
AutoModel.register(SmallCapConfig, SmallCap)
model = AutoModel.from_pretrained("Yova/SmallCapOPT7M")
model = model.to(device)
template = open('src/template.txt').read().strip() + ' '
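# template.txt holds the prompt skeleton; prep_strings fills it with the
# k retrieved captions before generation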
# load precomputed captions and the retrieval index
captions = json.load(open('coco_index_captions.json'))
retrieval_model, feature_extractor_retrieval = clip.load("RN50x64", device=device)
retrieval_index = faiss.read_index('coco_index')
# optionally move the index to GPU:
#res = faiss.StandardGpuResources()
#retrieval_index = faiss.index_cpu_to_gpu(res, 0, retrieval_index)
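# For reference, an index like 'coco_index' can be built offline from CLIP
# text embeddings of the captions (a sketch, assuming `embeddings` is an
# (N, d) float32 numpy array whose row i matches captions[i]):
#
#   faiss.normalize_L2(embeddings)
#   index = faiss.IndexFlatIP(embeddings.shape[1])
#   index.add(embeddings)
#   faiss.write_index(index, 'coco_index')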
def retrieve_caps(image_embedding, index, k=4):
    # L2-normalise the query so inner-product search over the (normalised)
    # index is equivalent to cosine similarity
    xq = image_embedding.astype(np.float32)
    faiss.normalize_L2(xq)
    D, I = index.search(xq, k)
    return I
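# Example: for a (1, d) CLIP image embedding `emb`,
# retrieve_caps(emb, retrieval_index, k=4)[0] gives the indices of the four
# nearest captions in `captions`.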
def generate_caption(image):
    # embed the input image with the retrieval CLIP model and fetch the
    # nearest captions from the index
    pixel_values_retrieval = feature_extractor_retrieval(image).to(device)
    with torch.no_grad():
        image_embedding = retrieval_model.encode_image(pixel_values_retrieval.unsqueeze(0)).cpu().numpy()
    nns = retrieve_caps(image_embedding, retrieval_index)[0]
    caps = [captions[i] for i in nns][:4]

    # prepare prompt
    decoder_input_ids = prep_strings('', tokenizer, template=template, retrieved_caps=caps, k=4, is_test=True)

    # generate caption
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
    with torch.no_grad():
        pred = model.generate(pixel_values.to(device),
                              decoder_input_ids=torch.tensor([decoder_input_ids]).to(device),
                              max_new_tokens=25, no_repeat_ngram_size=0, length_penalty=0,
                              min_length=1, num_beams=3, eos_token_id=tokenizer.eos_token_id)
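    # Decoding note: beam search (num_beams=3) with EOS set to '.' stops at
    # the first full stop; max_new_tokens=25 caps the caption length.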
    retrieved_caps = "Retrieved captions:\n{}\n{}\n{}\n{}".format(*caps)
    return str(postprocess_preds(tokenizer.decode(pred[0]), tokenizer)) + "\n\n\n" + retrieved_caps
# build and launch the Gradio demo
image = gr.Image(type="pil")
textbox = gr.Textbox(placeholder="Generated caption and retrieved captions...", lines=4)
title = "SmallCap Demo"

gr.Interface(
    fn=generate_caption, inputs=image, outputs=textbox, title=title
).launch()
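# When running locally rather than on Spaces, a public link can be created
# with .launch(share=True).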