from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
import torch
from torchvision import transforms
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.metrics.pairwise import cosine_similarity
import textstat
import spacy
import re
import gradio as gr
# Initialize the processor and model for the GIT-large captioning model fine-tuned on COCO
processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
detection_pipe = pipeline("object-detection", model="facebook/detr-resnet-50")
classification_pipe = pipeline("zero-shot-image-classification", model="openai/clip-vit-large-patch14")
# Initialize the pipeline for the ViT-GPT2 captioning model
vit_pipeline = pipeline(task="image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
# Move the GIT model to the GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
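# Note: the pipelines above run on the CPU by default; transformers' pipeline() accepts
# a device argument (e.g. device=0) to place them on the GPU when one is available.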
def generate_text_and_caption(image):
# Define the preprocessing pipeline for the image
preprocess = transforms.Compose([
transforms.Resize((256, 256)), # Resize to 256x256, change this to match the required dimensions
transforms.CenterCrop(224), # Center crop to 224x224, change this to match the required dimensions
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # Normalize with ImageNet mean and std
])
# Apply the preprocessing pipeline to the image
preprocessed_image = preprocess(image).unsqueeze(0).to(device) # unsqueeze to add batch dimension
# For large COCO model
generated_ids = model.generate(pixel_values=preprocessed_image, max_length=20)
caption1 = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# For VIT model
vit_output = vit_pipeline(image)
caption2_info = vit_output[0] if vit_output else {"generated_text": "N/A"}
caption2 = caption2_info.get('generated_text', 'N/A')
return caption1, caption2
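# Minimal alternative sketch (not wired into the app): let the GIT processor handle
# resizing and normalization instead of the hand-rolled torchvision transform above,
# whose ImageNet statistics may not exactly match what this checkpoint expects.
def generate_caption_with_processor(image):
    inputs = processor(images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=20)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]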
def get_unique_refined_labels(image):
original_output = detection_pipe(image)
filtered_output = [item for item in original_output if item['score'] >= 0.95]
unique_refined_labels = {}
for item in filtered_output:
box = item['box']
label = item['label']
xmin, ymin, xmax, ymax = box['xmin'], box['ymin'], box['xmax'], box['ymax']
cropped_image = image.crop((xmin, ymin, xmax, ymax))
predictions = classification_pipe(cropped_image, candidate_labels=[label])
if predictions:
top_prediction = sorted(predictions, key=lambda x: x['score'], reverse=True)[0]
top_label = top_prediction['label']
top_score = top_prediction['score']
if top_label not in unique_refined_labels or unique_refined_labels[top_label] < top_score:
unique_refined_labels[top_label] = top_score
return unique_refined_labels, original_output, filtered_output
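# Illustrative shape of the result (values are made up): the dict maps each refined
# label to its best classification score, e.g. {"dog": 0.98, "frisbee": 0.97}.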
# Load NLP model for entity extraction
nlp = spacy.load("en_core_web_sm")
def extract_main_words(text):
doc = nlp(text)
main_words = [token.lemma_ for token in doc if token.pos_ == 'NOUN']
return main_words
def get_topics(text):
# Vectorize the text
vectorizer = CountVectorizer()
text_vec = vectorizer.fit_transform([text])
# Fit LDA model to get topics
lda = LatentDirichletAllocation(n_components=1, random_state=0)
lda.fit(text_vec)
# Get the top words per topic (assuming one topic for simplicity)
feature_names = vectorizer.get_feature_names_out()
top_words = [feature_names[i] for i in lda.components_[0].argsort()[:-10 - 1:-1]]
return top_words
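# With a single caption and a single topic, the LDA step effectively surfaces the
# caption's most frequent terms, e.g. get_topics("a dog catches a frisbee in the park")
# returns words such as "dog", "frisbee", and "park" (order may vary).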
def check_readability(caption):
# Compute the Flesch Reading Ease score of the caption
reading_ease_score = textstat.flesch_reading_ease(caption)
return reading_ease_score
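# Flesch Reading Ease is a standard readability index: higher scores mean easier text,
# which is why the tie-break in evaluate_caption below prefers the higher score.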
def compute_similarity(caption1, caption2):
vectorizer = TfidfVectorizer().fit_transform([caption1, caption2])
vectors = vectorizer.toarray()
cosine_sim = cosine_similarity(vectors)
# The similarity between the captions is the off-diagonal value of the cosine_sim matrix
similarity_score = cosine_sim[0, 1]
return similarity_score
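# Quick sanity check (illustrative): compute_similarity("a dog on the grass",
# "a dog in the grass") yields a clearly higher value than two unrelated captions would.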
def evaluate_caption(image, caption1, caption2, unique_refined_labels):
# Scores initialization
score_caption1 = 0
score_caption2 = 0
# Initialize object presence scores
object_presence_score1 = 0
object_presence_score2 = 0
    # Extract the main nouns from each caption
main_words_caption1 = extract_main_words(caption1)
main_words_caption2 = extract_main_words(caption2)
# Check for object presence using unique_refined_labels
object_presence_score1 += sum([1 for word in main_words_caption1 if word in unique_refined_labels.keys()])
object_presence_score2 += sum([1 for word in main_words_caption2 if word in unique_refined_labels.keys()])
# Entity Extraction
entities_caption1 = [ent.text for ent in nlp(caption1).ents]
entities_caption2 = [ent.text for ent in nlp(caption2).ents]
    # Check whether the extracted entities match the detected object labels
score_caption1 += sum([1 for entity in entities_caption1 if entity in unique_refined_labels.keys()])
score_caption2 += sum([1 for entity in entities_caption2 if entity in unique_refined_labels.keys()])
# Topic Modeling
topics_caption1 = get_topics(caption1)
topics_caption2 = get_topics(caption2)
# Check for topic relevance using unique_refined_labels
score_caption1 += sum([1 for topic in topics_caption1 if topic in unique_refined_labels.keys()])
score_caption2 += sum([1 for topic in topics_caption2 if topic in unique_refined_labels.keys()])
# Implement custom rules
    def custom_rules(caption):
        score = 0
        if not caption:  # guard against empty captions before indexing into them
            return -2
        # Rule for starting with a capital letter
        if not caption[0].isupper():
            score -= 1
        # Rule for ending with punctuation
        if caption[-1] not in ['.', '!', '?']:
            score -= 1
        return score
# Custom rule scores
custom_score1 = custom_rules(caption1)
custom_score2 = custom_rules(caption2)
# Update scores based on custom rules
    score_caption1 += custom_score1  # custom_rules returns 0 or negative values, so adding them applies the penalties
score_caption2 += custom_score2
# Check length
length_caption1 = len(caption1.split())
length_caption2 = len(caption2.split())
if length_caption1 < 3: # assuming a reasonable caption should have at least 3 words
score_caption1 -= 3 # arbitrary penalty
if length_caption2 < 3:
score_caption2 -= 3 # arbitrary penalty
    # Compute caption similarity and define thresholds
    similarity_score = compute_similarity(caption1, caption2)
    similarity_threshold = 0.9  # Replace this with whatever you consider "close enough"
score_difference = abs(score_caption1 - score_caption2)
score_threshold = 2 # Replace this with whatever you consider "close enough"
if score_difference <= score_threshold:
if similarity_score > similarity_threshold:
readability_score_caption1 = check_readability(caption1)
readability_score_caption2 = check_readability(caption2)
return caption1 if readability_score_caption1 > readability_score_caption2 else caption2
else:
return caption1 if score_caption1 > score_caption2 else caption2
# Fallback return statement
    return caption2 if score_caption2 > score_caption1 else caption1
# Define the post_process_caption function
def post_process_caption(caption):
# Remove [unusedX] tokens, where X is any number
cleaned_caption = re.sub(r'\[\s*unused\d+\s*\](, )? ?', '', caption)
return cleaned_caption
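# Illustrative example: post_process_caption("a dog [unused0] running") -> "a dog running"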
def process_image(image_path):
image = Image.open(image_path).convert("RGB")
caption1, caption2 = generate_text_and_caption(image)
unique_refined_labels, _, _ = get_unique_refined_labels(image)
    # Clean up caption1 by stripping [unusedX] tokens
    caption1 = post_process_caption(caption1)
    # Evaluate the captions and pick the better one
    better_caption = evaluate_caption(image, caption1, caption2, unique_refined_labels)
return caption1, caption2, better_caption
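# Local usage sketch (hypothetical file path, kept out of the Gradio flow):
#   caption1, caption2, best = process_image("example.jpg")
#   print(best)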
img_cap_ui = gr.Interface(
fn=process_image,
title="Image Captioning with Automactic Evaluation",
description="Caution: this is a research experiment for personal use, please review the captions before using.",
inputs=gr.inputs.Image(type="filepath",label="Add your image"),
outputs=[gr.Textbox(label="Caption from the git-coco model"),
gr.Textbox(label="Caption from the nlp-connect model"),
gr.Textbox(label="Suggested caption after automatic evaluation")],
article="The caption evaluation method use a simple voting scheme from outputs of 2 additional models. This is an expirement, please use your judgment/edit if you use the generated caption.",
theme=gr.themes.Soft()
)
img_cap_ui.launch()