"""Gradio app: caption an image with two models and suggest the better caption.

Captions come from GIT-large-COCO and ViT-GPT2. They are scored against objects
found by DETR (confirmed with CLIP), using spaCy, LDA topics, formatting rules,
TF-IDF similarity and Flesch readability.
"""
import logging
import re

import gradio as gr
import matplotlib.pyplot as plt  # NOTE(review): not used below; kept to avoid breaking notebook cells
import numpy as np  # NOTE(review): not used below
import spacy
import textstat
import torch
from PIL import Image
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from torchvision import transforms  # NOTE(review): no longer used; GIT preprocessing goes through `processor`
from transformers import AutoModelForCausalLM, AutoProcessor, pipeline

logger = logging.getLogger(__name__)

# Captioning model 1: GIT-large fine-tuned on COCO, driven via processor + generate().
processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")

# Object detector and zero-shot classifier that produce the reference labels.
detection_pipe = pipeline("object-detection", model="facebook/detr-resnet-50")
classification_pipe = pipeline(
    "zero-shot-image-classification", model="openai/clip-vit-large-patch14"
)

# Captioning model 2: ViT encoder + GPT-2 decoder via the image-to-text pipeline.
vit_pipeline = pipeline(task="image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

# Run the GIT model on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# spaCy English pipeline used for noun lemmas and named entities.
nlp = spacy.load("en_core_web_sm")

# Tuning knobs for the heuristic evaluation.
DETECTION_SCORE_THRESHOLD = 0.95  # keep only detections at least this confident
MIN_CAPTION_WORDS = 3  # a reasonable caption should have at least this many words
SHORT_CAPTION_PENALTY = 3  # arbitrary penalty for captions shorter than that
SIMILARITY_THRESHOLD = 0.9  # TF-IDF cosine above which captions count as "the same"
SCORE_THRESHOLD = 2  # score gap at or below which captions count as "close"

# Matches "[unusedN]" placeholder tokens plus an optional trailing ", " and/or space.
_UNUSED_TOKEN_RE = re.compile(r'\[\s*unused\d+\s*\](, )? ?')


def generate_text_and_caption(image, max_length=20):
    """Caption *image* with both captioning models.

    Args:
        image: an RGB ``PIL.Image.Image``.
        max_length: maximum generated sequence length for the GIT model.

    Returns:
        ``(caption1, caption2)``: the GIT-large-COCO caption and the ViT-GPT2
        caption (``"N/A"`` when the pipeline returns no result).
    """
    # Let the checkpoint's own processor resize/normalise the image. The former
    # hand-written torchvision pipeline hard-coded ImageNet mean/std and a 224
    # crop, which need not match what this model was configured with.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
    generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
    caption1 = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    vit_output = vit_pipeline(image)
    caption2_info = vit_output[0] if vit_output else {"generated_text": "N/A"}
    caption2 = caption2_info.get("generated_text", "N/A")
    return caption1, caption2


def get_unique_refined_labels(image, score_threshold=DETECTION_SCORE_THRESHOLD):
    """Detect objects in *image* and confirm each crop's label with CLIP.

    Args:
        image: an RGB ``PIL.Image.Image``.
        score_threshold: minimum detector score for a detection to be kept.

    Returns:
        ``(unique_refined_labels, original_output, filtered_output)`` where the
        first item maps each label to its best classifier score, and the other
        two are the raw and threshold-filtered detector outputs.
    """
    original_output = detection_pipe(image)
    filtered_output = [item for item in original_output if item["score"] >= score_threshold]

    unique_refined_labels = {}
    for item in filtered_output:
        box = item["box"]
        cropped_image = image.crop((box["xmin"], box["ymin"], box["xmax"], box["ymax"]))
        # NOTE(review): with a single candidate label the classifier presumably
        # normalises to a score of ~1.0, so this only echoes the detector's label;
        # supply more candidates if real refinement is intended.
        predictions = classification_pipe(cropped_image, candidate_labels=[item["label"]])
        if not predictions:
            continue
        top = max(predictions, key=lambda prediction: prediction["score"])
        # Keep the highest score seen for each label.
        if unique_refined_labels.get(top["label"], float("-inf")) < top["score"]:
            unique_refined_labels[top["label"]] = top["score"]
    return unique_refined_labels, original_output, filtered_output


def extract_main_words(text):
    """Return the lemmas of every token spaCy tags as a NOUN in *text*."""
    doc = nlp(text)
    return [token.lemma_ for token in doc if token.pos_ == "NOUN"]


def get_topics(text, n_top_words=10):
    """Return up to *n_top_words* top words of a one-topic LDA fit on *text*.

    Returns an empty list when *text* has no token the vectorizer keeps (for
    example an empty string or ``"N/A"``, whose tokens are all one character).
    """
    vectorizer = CountVectorizer()
    try:
        text_vec = vectorizer.fit_transform([text])
    except ValueError:
        # CountVectorizer raises on an empty vocabulary; there are no topics.
        return []

    lda = LatentDirichletAllocation(n_components=1, random_state=0)
    lda.fit(text_vec)

    feature_names = vectorizer.get_feature_names_out()
    # Word indices ordered by descending weight in the single topic.
    top_indices = lda.components_[0].argsort()[::-1][:n_top_words]
    return [feature_names[i] for i in top_indices]


def check_readability(caption):
    """Return the Flesch Reading Ease score of *caption* (higher = easier)."""
    return textstat.flesch_reading_ease(caption)


def compute_similarity(caption1, caption2):
    """Return the TF-IDF cosine similarity between the two captions.

    Returns 0.0 when neither caption contains a token the vectorizer keeps.
    """
    try:
        tfidf = TfidfVectorizer().fit_transform([caption1, caption2])
    except ValueError:
        # Empty vocabulary: nothing to compare, so treat the captions as dissimilar.
        return 0.0
    # Off-diagonal entry of the 2x2 similarity matrix.
    return cosine_similarity(tfidf)[0, 1]


def _formatting_penalty(caption):
    """Return 0, -1 or -2: one point off each for a caption that does not start
    with an uppercase letter or does not end with '.', '!' or '?'."""
    if not caption:
        # An empty caption fails both rules.
        return -2
    penalty = 0
    if not caption[0].isupper():
        penalty -= 1
    if caption[-1] not in ".!?":
        penalty -= 1
    return penalty


def _score_caption(caption, unique_refined_labels):
    """Heuristic score of *caption* against the detected object labels."""
    # Compare case-insensitively so a capitalised noun still matches its label.
    labels = {label.lower() for label in unique_refined_labels}

    score = 0
    # Object presence: nouns (lemmatised) that name a detected object.
    score += sum(1 for word in extract_main_words(caption) if word.lower() in labels)
    # Named entities that name a detected object.
    score += sum(1 for ent in nlp(caption).ents if ent.text.lower() in labels)
    # LDA topic words that name a detected object.
    score += sum(1 for topic in get_topics(caption) if topic.lower() in labels)
    # Formatting rules (capital first letter, terminal punctuation).
    score += _formatting_penalty(caption)
    # Very short captions are penalised.
    if len(caption.split()) < MIN_CAPTION_WORDS:
        score -= SHORT_CAPTION_PENALTY
    return score


def evaluate_caption(image, caption1, caption2, unique_refined_labels):
    """Return whichever of *caption1* / *caption2* the heuristics prefer.

    If the two heuristic scores are within ``SCORE_THRESHOLD`` and the captions
    are more than ``SIMILARITY_THRESHOLD`` similar, the more readable caption
    wins. Otherwise the higher-scoring caption wins, with ties going to
    *caption2*.

    Args:
        image: unused; kept for interface compatibility.
        caption1, caption2: the candidate captions.
        unique_refined_labels: mapping of detected label -> confidence.
    """
    score_caption1 = _score_caption(caption1, unique_refined_labels)
    score_caption2 = _score_caption(caption2, unique_refined_labels)

    similarity_score = compute_similarity(caption1, caption2)
    logger.debug("caption similarity: %s", similarity_score)

    scores_close = abs(score_caption1 - score_caption2) <= SCORE_THRESHOLD
    if scores_close and similarity_score > SIMILARITY_THRESHOLD:
        # Near-identical captions with near-equal scores: prefer the easier read.
        if check_readability(caption1) > check_readability(caption2):
            return caption1
        return caption2

    # Previously the large-gap case compared score_caption2 with itself and always
    # returned caption1; the higher score now wins in every case.
    return caption1 if score_caption1 > score_caption2 else caption2


def post_process_caption(caption):
    """Remove ``[unusedN]`` placeholder tokens and trim surrounding whitespace."""
    return _UNUSED_TOKEN_RE.sub("", caption).strip()


def process_image(image_path):
    """Gradio callback: caption the image at *image_path* and pick the better caption.

    Returns:
        ``(caption1, caption2, better_caption)``; all empty strings when no
        image was supplied.
    """
    if image_path is None:
        return "", "", ""
    # Close the file handle once the pixels have been converted into memory.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    caption1, caption2 = generate_text_and_caption(image)
    unique_refined_labels, _, _ = get_unique_refined_labels(image)
    caption1 = post_process_caption(caption1)

    better_caption = evaluate_caption(image, caption1, caption2, unique_refined_labels)
    return caption1, caption2, better_caption


img_cap_ui = gr.Interface(
    fn=process_image,
    title="Image Captioning with Automatic Evaluation",
    description="Caution: this is a research experiment for personal use, please review the captions before using.",
    inputs=gr.Image(type="filepath", label="Add your image"),
    outputs=[
        gr.Textbox(label="Caption from the git-coco model"),
        gr.Textbox(label="Caption from the nlp-connect model"),
        gr.Textbox(label="Suggested caption after automatic evaluation"),
    ],
    article="The caption evaluation method uses a simple voting scheme over the outputs of 3 additional models. This is an experiment; please use your judgment to edit the generated caption before using it.",
    theme=gr.themes.Soft(),
)
img_cap_ui.launch()