from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import spacy
import torch
from torchvision import transforms
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.metrics.pairwise import cosine_similarity
import textstat
import re

# Initialize the processor and model for the large COCO model
processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")

detection_pipe = pipeline("object-detection", model="facebook/detr-resnet-50")
classification_pipe = pipeline("zero-shot-image-classification", model="openai/clip-vit-large-patch14")

# Initialize the pipeline for the ViT model
vit_pipeline = pipeline(task="image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

# Move the COCO model to the device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def generate_text_and_caption(image):
    # Define the preprocessing pipeline for the image.
    # Note: processor(images=image, return_tensors="pt").pixel_values would apply the
    # model's own preprocessing; the manual transform below is an approximation.
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),  # Resize to 256x256; change this to match the required dimensions
        transforms.CenterCrop(224),     # Center crop to 224x224; change this to match the required dimensions
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),  # Normalize with ImageNet mean and std
    ])

    # Apply the preprocessing pipeline to the image (unsqueeze adds the batch dimension)
    preprocessed_image = preprocess(image).unsqueeze(0).to(device)

    # For the large COCO model
    generated_ids = model.generate(pixel_values=preprocessed_image, max_length=20)
    caption1 = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # For the ViT model
    vit_output = vit_pipeline(image)
    caption2_info = vit_output[0] if vit_output else {"generated_text": "N/A"}
    caption2 = caption2_info.get('generated_text', 'N/A')

    return caption1, caption2


def get_unique_refined_labels(image):
    original_output = detection_pipe(image)
    filtered_output = [item for item in original_output if item['score'] >= 0.95]
    unique_refined_labels = {}
    for item in filtered_output:
        box = item['box']
        label = item['label']
        xmin, ymin, xmax, ymax = box['xmin'], box['ymin'], box['xmax'], box['ymax']
        cropped_image = image.crop((xmin, ymin, xmax, ymax))
        predictions = classification_pipe(cropped_image, candidate_labels=[label])
        if predictions:
            top_prediction = sorted(predictions, key=lambda x: x['score'], reverse=True)[0]
            top_label = top_prediction['label']
            top_score = top_prediction['score']
            # Keep the highest score seen for each label
            if top_label not in unique_refined_labels or unique_refined_labels[top_label] < top_score:
                unique_refined_labels[top_label] = top_score
    return unique_refined_labels, original_output, filtered_output


# Load NLP model for entity extraction
nlp = spacy.load("en_core_web_sm")


def extract_main_words(text):
    doc = nlp(text)
    main_words = [token.lemma_ for token in doc if token.pos_ == 'NOUN']
    return main_words


def get_topics(text):
    # Vectorize the text
    vectorizer = CountVectorizer()
    text_vec = vectorizer.fit_transform([text])

    # Fit an LDA model to get topics
    lda = LatentDirichletAllocation(n_components=1, random_state=0)
    lda.fit(text_vec)

    # Get the top words per topic (assuming one topic for simplicity)
    feature_names = vectorizer.get_feature_names_out()
    top_words = [feature_names[i] for i in lda.components_[0].argsort()[:-10 - 1:-1]]
    return top_words
def check_readability(caption):
    # Compute the Flesch Reading Ease score of the caption
    reading_ease_score = textstat.flesch_reading_ease(caption)
    return reading_ease_score


def compute_similarity(caption1, caption2):
    vectorizer = TfidfVectorizer().fit_transform([caption1, caption2])
    vectors = vectorizer.toarray()
    cosine_sim = cosine_similarity(vectors)
    # The similarity between the captions is the off-diagonal value of the cosine_sim matrix
    similarity_score = cosine_sim[0, 1]
    return similarity_score


# Cell 3
def evaluate_caption(image, caption1, caption2, unique_refined_labels):
    # Scores initialization
    score_caption1 = 0
    score_caption2 = 0

    # Initialize object presence scores
    object_presence_score1 = 0
    object_presence_score2 = 0

    # Extract the main (noun) words from each caption
    main_words_caption1 = extract_main_words(caption1)
    main_words_caption2 = extract_main_words(caption2)

    # Check for object presence using unique_refined_labels
    object_presence_score1 += sum([1 for word in main_words_caption1 if word in unique_refined_labels.keys()])
    object_presence_score2 += sum([1 for word in main_words_caption2 if word in unique_refined_labels.keys()])

    # Entity extraction
    entities_caption1 = [ent.text for ent in nlp(caption1).ents]
    entities_caption2 = [ent.text for ent in nlp(caption2).ents]

    # Check for entity presence using unique_refined_labels
    score_caption1 += sum([1 for entity in entities_caption1 if entity in unique_refined_labels.keys()])
    score_caption2 += sum([1 for entity in entities_caption2 if entity in unique_refined_labels.keys()])

    # Topic modeling
    topics_caption1 = get_topics(caption1)
    topics_caption2 = get_topics(caption2)

    # Check for topic relevance using unique_refined_labels
    score_caption1 += sum([1 for topic in topics_caption1 if topic in unique_refined_labels.keys()])
    score_caption2 += sum([1 for topic in topics_caption2 if topic in unique_refined_labels.keys()])

    # Custom formatting rules (each violation is a penalty, returned as a negative score)
    def custom_rules(caption):
        score = 0
        # Rule: the caption should start with a capital letter
        if not caption[0].isupper():
            score -= 1
        # Rule: the caption should end with punctuation
        if caption[-1] not in ['.', '!', '?']:
            score -= 1
        return score

    # Custom rule scores
    custom_score1 = custom_rules(caption1)
    custom_score2 = custom_rules(caption2)

    # Apply the custom-rule penalties (the returned scores are already negative)
    score_caption1 += custom_score1
    score_caption2 += custom_score2

    # Check length
    length_caption1 = len(caption1.split())
    length_caption2 = len(caption2.split())
    if length_caption1 < 3:  # assuming a reasonable caption should have at least 3 words
        score_caption1 -= 3  # arbitrary penalty
    if length_caption2 < 3:
        score_caption2 -= 3  # arbitrary penalty

    # Define similarity threshold
    similarity_score = compute_similarity(caption1, caption2)
    print(similarity_score)
    similarity_threshold = 0.9  # Replace this with whatever you consider "close enough"

    score_difference = abs(score_caption1 - score_caption2)
    score_threshold = 2  # Replace this with whatever you consider "close enough"

    if score_difference <= score_threshold:
        if similarity_score > similarity_threshold:
            # Scores are close and the captions are similar: prefer the more readable one
            readability_score_caption1 = check_readability(caption1)
            readability_score_caption2 = check_readability(caption2)
            return caption1 if readability_score_caption1 > readability_score_caption2 else caption2
        else:
            return caption1 if score_caption1 > score_caption2 else caption2

    # Fallback: the scores differ substantially, so return the higher-scoring caption
    return caption2 if score_caption2 > score_caption1 else caption1
# Define the post_process_caption function
def post_process_caption(caption):
    # Remove [unusedX] tokens, where X is any number
    cleaned_caption = re.sub(r'\[\s*unused\d+\s*\](, )? ?', '', caption)
    return cleaned_caption


def process_image(image_path):
    image = Image.open(image_path).convert("RGB")
    caption1, caption2 = generate_text_and_caption(image)
    unique_refined_labels, _, _ = get_unique_refined_labels(image)

    # Clean up caption1 (strip [unusedX] tokens)
    caption1 = post_process_caption(caption1)

    # Evaluate the captions
    better_caption = evaluate_caption(image, caption1, caption2, unique_refined_labels)
    return caption1, caption2, better_caption


import gradio as gr

img_cap_ui = gr.Interface(
    fn=process_image,
    title="Image Captioning with Automatic Evaluation",
    description="Caution: this is a research experiment for personal use; please review the captions before using them.",
    inputs=gr.Image(type="filepath", label="Add your image"),
    outputs=[gr.Textbox(label="Caption from the git-coco model"),
             gr.Textbox(label="Caption from the nlp-connect model"),
             gr.Textbox(label="Suggested caption after automatic evaluation")],
    article="The caption evaluation method uses a simple voting scheme over the outputs of three additional models. This is an experiment; please use your judgment to edit the generated caption before using it.",
    theme=gr.themes.Soft()
)

img_cap_ui.launch(debug=True)
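# A minimal sketch of calling the captioning pipeline directly, without the Gradio UI.
# "example.jpg" is a hypothetical local image path; substitute any RGB image you have.
# Left commented out so it does not execute behind the blocking launch(debug=True) call above.
# caption1, caption2, better_caption = process_image("example.jpg")
# print("git-coco:  ", caption1)
# print("vit-gpt2:  ", caption2)
# print("suggested: ", better_caption)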