from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import textstat
import spacy
import re
# Initialize the processor and model for the large COCO model
processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")

detection_pipe = pipeline("object-detection", model="facebook/detr-resnet-50")
classification_pipe = pipeline("zero-shot-image-classification", model="openai/clip-vit-large-patch14")

# Initialize the pipeline for the VIT model
vit_pipeline = pipeline(task="image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

# Move the COCO model to the device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
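# Note: the detection and classification pipelines above run on CPU by default;
# passing device=0 (or device=device) to pipeline(...) would place them on the GPU as well.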
def generate_text_and_caption(image):
    # Define the preprocessing pipeline for the image
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),  # Resize to 256x256, change this to match the required dimensions
        transforms.CenterCrop(224),     # Center crop to 224x224, change this to match the required dimensions
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize with ImageNet mean and std
    ])
    # Apply the preprocessing pipeline to the image
    preprocessed_image = preprocess(image).unsqueeze(0).to(device)  # unsqueeze to add batch dimension
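    # Note: the GIT processor could also build the model input directly, e.g.
    # processor(images=image, return_tensors="pt").pixel_values, which applies the
    # model's own resizing and normalization; the manual torchvision pipeline above
    # approximates it with ImageNet statistics.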
    # For large COCO model
    generated_ids = model.generate(pixel_values=preprocessed_image, max_length=20)
    caption1 = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # For VIT model
    vit_output = vit_pipeline(image)
    caption2_info = vit_output[0] if vit_output else {"generated_text": "N/A"}
    caption2 = caption2_info.get('generated_text', 'N/A')

    return caption1, caption2
def get_unique_refined_labels(image):
    original_output = detection_pipe(image)
    filtered_output = [item for item in original_output if item['score'] >= 0.95]
    unique_refined_labels = {}
    for item in filtered_output:
        box = item['box']
        label = item['label']
        xmin, ymin, xmax, ymax = box['xmin'], box['ymin'], box['xmax'], box['ymax']
        cropped_image = image.crop((xmin, ymin, xmax, ymax))
        predictions = classification_pipe(cropped_image, candidate_labels=[label])
        if predictions:
            top_prediction = sorted(predictions, key=lambda x: x['score'], reverse=True)[0]
            top_label = top_prediction['label']
            top_score = top_prediction['score']
            if top_label not in unique_refined_labels or unique_refined_labels[top_label] < top_score:
                unique_refined_labels[top_label] = top_score
    return unique_refined_labels, original_output, filtered_output
# Load NLP model for entity extraction
nlp = spacy.load("en_core_web_sm")

def extract_main_words(text):
    doc = nlp(text)
    main_words = [token.lemma_ for token in doc if token.pos_ == 'NOUN']
    return main_words
def get_topics(text):
    # Vectorize the text
    vectorizer = CountVectorizer()
    text_vec = vectorizer.fit_transform([text])
    # Fit LDA model to get topics
    lda = LatentDirichletAllocation(n_components=1, random_state=0)
    lda.fit(text_vec)
    # Get the top words per topic (assuming one topic for simplicity)
    feature_names = vectorizer.get_feature_names_out()
    top_words = [feature_names[i] for i in lda.components_[0].argsort()[:-10 - 1:-1]]
    return top_words
def check_readability(caption):
    # Compute the Flesch Reading Ease score of the caption
    reading_ease_score = textstat.flesch_reading_ease(caption)
    return reading_ease_score
def compute_similarity(caption1, caption2):
    vectorizer = TfidfVectorizer().fit_transform([caption1, caption2])
    vectors = vectorizer.toarray()
    cosine_sim = cosine_similarity(vectors)
    # The similarity between the captions is the off-diagonal value of the cosine_sim matrix
    similarity_score = cosine_sim[0, 1]
    return similarity_score
def evaluate_caption(image, caption1, caption2, unique_refined_labels):
    # Scores initialization
    score_caption1 = 0
    score_caption2 = 0

    # Initialize object presence scores
    object_presence_score1 = 0
    object_presence_score2 = 0

    # Extract the main (noun) words from each caption
    main_words_caption1 = extract_main_words(caption1)
    main_words_caption2 = extract_main_words(caption2)

    # Check for object presence using unique_refined_labels
    object_presence_score1 += sum([1 for word in main_words_caption1 if word in unique_refined_labels.keys()])
    object_presence_score2 += sum([1 for word in main_words_caption2 if word in unique_refined_labels.keys()])

    # Entity Extraction
    entities_caption1 = [ent.text for ent in nlp(caption1).ents]
    entities_caption2 = [ent.text for ent in nlp(caption2).ents]

    # Check for entity presence using unique_refined_labels
    score_caption1 += sum([1 for entity in entities_caption1 if entity in unique_refined_labels.keys()])
    score_caption2 += sum([1 for entity in entities_caption2 if entity in unique_refined_labels.keys()])

    # Topic Modeling
    topics_caption1 = get_topics(caption1)
    topics_caption2 = get_topics(caption2)

    # Check for topic relevance using unique_refined_labels
    score_caption1 += sum([1 for topic in topics_caption1 if topic in unique_refined_labels.keys()])
    score_caption2 += sum([1 for topic in topics_caption2 if topic in unique_refined_labels.keys()])
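    # Note: with a single short caption and one LDA component, get_topics largely
    # returns the caption's own words, so the topic score mostly reinforces the
    # noun-overlap score above.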
    # Implement custom rules for basic caption formatting
    def custom_rules(caption):
        score = 0
        # Rule for starting with a capital letter
        if not caption[0].isupper():
            score -= 1
        # Rule for ending with punctuation
        if caption[-1] not in ['.', '!', '?']:
            score -= 1
        return score

    # Custom rule scores (custom_rules returns 0 or a negative penalty)
    custom_score1 = custom_rules(caption1)
    custom_score2 = custom_rules(caption2)

    # Update scores based on custom rules
    score_caption1 += custom_score1
    score_caption2 += custom_score2
    # Check length
    length_caption1 = len(caption1.split())
    length_caption2 = len(caption2.split())
    if length_caption1 < 3:  # assuming a reasonable caption should have at least 3 words
        score_caption1 -= 3  # arbitrary penalty
    if length_caption2 < 3:
        score_caption2 -= 3  # arbitrary penalty

    # Define similarity and score thresholds
    similarity_score = compute_similarity(caption1, caption2)
    similarity_threshold = 0.9  # Replace this with whatever you consider "close enough"
    score_difference = abs(score_caption1 - score_caption2)
    score_threshold = 2  # Replace this with whatever you consider "close enough"

    if score_difference <= score_threshold:
        if similarity_score > similarity_threshold:
            # Scores are close and the captions are very similar: pick the more readable one
            readability_score_caption1 = check_readability(caption1)
            readability_score_caption2 = check_readability(caption2)
            return caption1 if readability_score_caption1 > readability_score_caption2 else caption2
        else:
            return caption1 if score_caption1 > score_caption2 else caption2

    # Fallback return statement
    return caption2 if score_caption2 > score_caption1 else caption1
# Define the post_process_caption function
def post_process_caption(caption):
    # Remove [unusedX] tokens, where X is any number
    cleaned_caption = re.sub(r'\[\s*unused\d+\s*\](, )? ?', '', caption)
    return cleaned_caption
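# Example (hypothetical input for illustration):
#   post_process_caption("a dog [unused0] playing in the park.") -> "a dog playing in the park."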
def process_image(image_path):
    image = Image.open(image_path).convert("RGB")
    caption1, caption2 = generate_text_and_caption(image)
    unique_refined_labels, _, _ = get_unique_refined_labels(image)
    # Clean up caption1 before evaluation
    caption1 = post_process_caption(caption1)
    # Evaluate the captions and pick the better one
    better_caption = evaluate_caption(image, caption1, caption2, unique_refined_labels)
    return caption1, caption2, better_caption
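# Example usage (assuming a local file such as "image_31.jpg" exists):
#   caption1, caption2, best = process_image("image_31.jpg")
#   print(best)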
import gradio as gr

img_cap_ui = gr.Interface(
    fn=process_image,
    title="Image Captioning with Automatic Evaluation",
    description="Caution: this is a research experiment for personal use; please review the captions before using them.",
    inputs=gr.Image(type="filepath", label="Add your image"),
    outputs=[gr.Textbox(label="Caption from the git-coco model", show_copy_button=True),
             gr.Textbox(label="Caption from the nlp-connect model", show_copy_button=True),
             gr.Textbox(label="Suggested caption after automatic evaluation", show_copy_button=True)],
    examples=["image_31.jpg", "image_41.jpg", "image_48.jpg", "image_50.jpg"],
    article="The caption evaluation method uses a simple voting scheme based on the outputs of two additional models. This is an experiment; please edit the generated caption before using it.",
    theme=gr.themes.Soft()
)

img_cap_ui.launch()