IAB_VIDEO_AD_CLASSIFIER / image_caption.py
arjunanand13's picture
Update image_caption.py
37ce9d7 verified
raw
history blame
No virus
3.68 kB
import argparse
from pathlib import Path
import os
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image
import io
import google.generativeai as genai
class Caption:
    """Produce text captions for images.

    Two independent back ends are available:

    * Gemini (remote): ``predict_image_caption_gemini`` sends one image and a
      fixed prompt to the model held in ``self.model``.
    * ViT-GPT2 (local): ``predict_step``, ``predict_from_memory`` and
      ``process_images`` run a Hugging Face vision-encoder-decoder
      checkpoint. It is loaded lazily on first use and kept in
      ``self.caption_model`` so that it never collides with ``self.model``.
    """

    # Environment variable consulted when no API key is passed explicitly.
    API_KEY_ENV_VAR = "GOOGLE_API_KEY"

    # Checkpoint and decoding settings for the local captioning path.
    LOCAL_CAPTION_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"
    MAX_CAPTION_LENGTH = 16
    NUM_BEAMS = 4

    def __init__(self, api_key=None, gemini_model_name="gemini-pro-vision"):
        """Configure the Gemini client; defer loading the local model.

        Args:
            api_key: Gemini API key. When omitted, the value of the
                ``GOOGLE_API_KEY`` environment variable is used. Credentials
                must never be committed to source control.
            gemini_model_name: Name handed to ``genai.GenerativeModel``.
                NOTE(review): the default is kept for backward compatibility;
                confirm the provider still serves this model.

        If no key is available, ``self.model`` is left as ``None`` and only
        the local captioning methods are usable.
        """
        self.api_key = api_key or os.environ.get(self.API_KEY_ENV_VAR)
        self.model = None
        if self.api_key:
            genai.configure(api_key=self.api_key)
            self.model = genai.GenerativeModel(model_name=gemini_model_name)

        # Local ViT-GPT2 state, populated by _load_local_captioner().
        self.caption_model = None
        self.feature_extractor = None
        self.tokenizer = None
        self.device = None
        self.gen_kwargs = {
            "max_length": self.MAX_CAPTION_LENGTH,
            "num_beams": self.NUM_BEAMS,
        }

    def _load_local_captioner(self):
        """Load the local ViT-GPT2 model, processor and tokenizer once."""
        if getattr(self, "caption_model", None) is not None:
            return

        # Function-scope imports keep this loader self-contained; these are
        # the same names the module already imports at the top of the file.
        import torch
        from transformers import (
            AutoTokenizer,
            ViTImageProcessor,
            VisionEncoderDecoderModel,
        )

        checkpoint = self.LOCAL_CAPTION_CHECKPOINT
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.feature_extractor = ViTImageProcessor.from_pretrained(checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.gen_kwargs = {
            "max_length": self.MAX_CAPTION_LENGTH,
            "num_beams": self.NUM_BEAMS,
        }
        local_model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
        local_model.to(self.device)
        # Assigned last so a failure part-way through is retried next call.
        self.caption_model = local_model

    def predict_step(self, image_paths):
        """Caption images stored on disk with the local model.

        Args:
            image_paths: Iterable of paths accepted by ``PIL.Image.open``.

        Returns:
            A list of caption strings, one per path, in input order.
        """
        images = []
        for image_path in image_paths:
            # convert() yields a fully loaded copy, so the underlying file
            # handle can be released as soon as the block exits.
            with Image.open(image_path) as opened:
                images.append(opened.convert("RGB"))
        return self.process_images(images)

    def predict_from_memory(self, image_buffers):
        """Caption images held in memory with the local model.

        Buffers that cannot be decoded are reported on stdout and skipped, so
        the result may be shorter than ``image_buffers``.

        Args:
            image_buffers: Iterable of file-like objects (e.g. ``io.BytesIO``)
                containing encoded image data.

        Returns:
            A list of caption strings for the buffers that decoded
            successfully; an empty list if none did.
        """
        images = []
        for image_buffer in image_buffers:
            # Ensure the buffer is positioned at the start.
            if isinstance(image_buffer, io.BytesIO):
                image_buffer.seek(0)
            try:
                # convert() forces a full decode here, so corrupt data is
                # caught per image instead of failing the whole batch later.
                images.append(Image.open(image_buffer).convert("RGB"))
            except Exception as e:
                print(f"Failed to process image buffer: {str(e)}")
                continue
        return self.process_images(images)

    def process_images(self, images):
        """Run the local captioning model over already-decoded images.

        Args:
            images: List of RGB ``PIL.Image`` objects.

        Returns:
            A list of whitespace-stripped captions, one per image; an empty
            list when ``images`` is empty.
        """
        if not images:
            return []
        self._load_local_captioner()

        features = self.feature_extractor(images=images, return_tensors="pt")
        pixel_values = features.pixel_values.to(self.device)
        output_ids = self.caption_model.generate(pixel_values, **self.gen_kwargs)
        decoded = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return [caption.strip() for caption in decoded]

    def predict_image_caption_gemini(self, img):
        """Describe a single image with the Gemini model.

        Args:
            img: Image object passed through to ``generate_content`` alongside
                the prompt.

        Returns:
            The response text produced by the model.

        Raises:
            RuntimeError: If no API key was available when the instance was
                created, so the Gemini client was never configured.
        """
        if getattr(self, "model", None) is None:
            raise RuntimeError(
                "Gemini is not configured: pass api_key=... to Caption() or "
                "set the GOOGLE_API_KEY environment variable."
            )
        prompt = "Describe the main focus of this image in detail."
        response = self.model.generate_content([prompt, img], stream=True)
        # Block until the streamed response is complete before reading text.
        response.resolve()
        print("Derived data",response.text)
        return response.text

    def get_args(self):
        """Parse the command-line arguments of the captioning script.

        Returns:
            The parsed namespace; ``input_img_paths`` holds one image path
            and defaults to ``"farmer.jpg"``.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "-i",
            "--input_img_paths",
            type=str,
            default="farmer.jpg",
            help="img for caption",
        )
        return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: caption the single image named on the command line
    # (or the parser's default) and print the resulting list of captions.
    captioner = Caption()
    cli_args = captioner.get_args()
    print(captioner.predict_step([cli_args.input_img_paths]))