import gradio as gr import os import threading import random from datasets import load_dataset, Dataset, Features, Value, concatenate_datasets from huggingface_hub import login import json import re # reset = False def check_word_count(caption): # Check if the caption has 3 or more words return gr.update(interactive=len(caption.split()) >= 3) # Authenticate with Hugging Face token = os.getenv("HUGGINGFACE_TOKEN") if token: login(token=token) else: print("HUGGINGFACE_TOKEN environment variable not set.") dataset_name = "GeorgeIbrahim/EGYCOCO" # Replace with your dataset name with open('nearest_neighbors_with_captions.json', 'r') as f: results = json.load(f) # Load or create the dataset try: dataset = load_dataset(dataset_name, split="train") dataset = dataset.filter(lambda example: example["image_id"] != "COCO_val2014_000000111367.jpg") print("Loaded existing dataset:", dataset) print("Dataset features:", dataset.features) # Check if 'split' is part of features # Check if the 'split' column exists; if not, add it if 'split' not in dataset.column_names: # Define the 'split' values based on `image_id` split_values = [ "dev" if example["image_id"] in results else "train" for example in dataset ] # Add 'split' column to the dataset dataset = dataset.add_column("split", split_values) print("Added 'split' column to dataset.") else: print("'split' column already exists.") # Create a dictionary to keep track of the highest annotation count for each image annotation_counts = {} for example in dataset: image_id = example["image_id"] count = example["annotation_count"] if image_id not in annotation_counts or count > annotation_counts[image_id]: annotation_counts[image_id] = count print("Annotation counts:", annotation_counts) except Exception as e: print(f"Error loading dataset: {e}") # Create an empty dataset if it doesn't exist features = Features({ 'image_id': Value(dtype='string'), 'caption': Value(dtype='string'), 'annotation_count': Value(dtype='int32'), 'split': Value(dtype='string') }) dataset = Dataset.from_dict({'image_id': [], 'caption': [], 'annotation_count': [], 'split': []}, features=features) annotation_counts = {} dataset.push_to_hub(dataset_name) # Push the empty dataset to Hugging Face # Initialize or reset data as needed based on the `reset` flag # if reset: # # Clear the annotation counts # annotation_counts = {} # shown_counts = {} # If you are tracking shown counts separately for images # # Optionally, clear or reinitialize the dataset # features = Features({ # 'image_id': Value(dtype='string'), # 'caption': Value(dtype='string'), # 'annotation_count': Value(dtype='int32'), # 'split': Value(dtype='string') # }) # dataset = Dataset.from_dict({ # 'image_id': [], # 'caption': [], # 'annotation_count': [], # 'split': [] # }, features=features) # # Push the reset dataset to Hugging Face or perform other necessary actions # dataset.push_to_hub(dataset_name) # print("Data has been reset.") image_folder = "images" image_files = [f for f in os.listdir(image_folder) if f.endswith(('.png', '.jpg', '.jpeg'))] len_files = len(image_files) lock = threading.Lock() def get_caption_for_image_id(image_path): """ Retrieve the caption for a given image_id from the JSON data. """ # Extract the numeric part of the image ID match = re.search(r'_(\d+)\.', image_path) if match: image_id = match.group(1).lstrip('0') # Remove leading zeros print("Searching for image_id:", image_id) # Debugging line # Check if image_id is a test image if image_id in results: print("Found caption in results:", results[image_id]["caption"]) # Debugging line return results[image_id]["caption"] # If image_id is not a test image, search in nearest neighbors for test_image_data in results.values(): for neighbor in test_image_data["nearest_neighbors"]: if neighbor["image_id"] == image_id: print("Found caption in nearest neighbors:", neighbor["caption"]) # Debugging line return neighbor["caption"] # Return None if the image_id is not found print("Caption not found for image_id:", image_id) # Debugging line return None # Function to get a random image that hasn’t been fully annotated def get_next_image(session_data): with lock: # Available images filter available_images = [] # Iterate over each image file to apply the filtering logic for img in image_files: # Match and extract the image_id from the filename match = re.search(r'_(\d+)\.', img) if match: image_id_2 = match.group(1).lstrip('0') # Remove leading zeros # Apply the filtering conditions if (img not in annotation_counts or (image_id_2 in results and annotation_counts.get(img, 0) < 2) or (image_id_2 not in results and annotation_counts.get(img, 0) == 0)): available_images.append(img) # print("Available images:", available_images) # Debugging line print(available_images) print("Remaining images: ", len_files - len(available_images)) # random.shuffle(available_images) # Check if the user already has an image if session_data["current_image"] is None and available_images: # Assign a new random image to the user session_data["current_image"] = random.choice(available_images) # print("Current image_id:", session_data["current_image"]) # Print the current image_id return os.path.join(image_folder, session_data["current_image"]) if session_data["current_image"] else None # Function to save the annotation to Hugging Face dataset and fetch the next image def save_annotation(caption, session_data): global dataset, annotation_counts # Declare global dataset and annotation_counts at the start of the function if session_data["current_image"] is None: return gr.update(visible=False), gr.update(value="All images have been annotated!"), gr.update(value="") with lock: image_id = session_data["current_image"] match = re.search(r'_(\d+)\.', image_id) image_2 = match.group(1).lstrip('0') split = "dev" if image_2 in results else "train" # Save caption or "skipped" based on user input if caption.strip().lower() == "skip": caption = "skipped" # Get current annotation count annotation_count = annotation_counts.get(image_id, 0) # Add the new annotation as a new row to the dataset new_data = Dataset.from_dict({ "image_id": [image_id], "caption": [caption], "annotation_count": [annotation_count + 1], "split": [split] }, features=Features({ 'image_id': Value(dtype='string'), 'caption': Value(dtype='string'), 'annotation_count': Value(dtype='int32'), 'split': Value(dtype='string') })) # Update the annotation count in the dictionary annotation_counts[image_id] = annotation_count + 1 # Concatenate with the existing dataset and push the updated dataset to Hugging Face dataset = concatenate_datasets([dataset, new_data]) dataset = dataset.filter(lambda example: example['caption'].strip() != "") dataset.push_to_hub(dataset_name) print("Pushed updated dataset") # # Clear user's current image if the validation image has been annotated twice # if (split == "train" and annotation_count > 1) or (split == "dev" and annotation_count > 2): session_data["current_image"] = None # Fetch the next image next_image = get_next_image(session_data) if next_image: next_caption = get_caption_for_image_id(os.path.basename(next_image)) # Retrieve the caption for the new image print("Next image_id:", os.path.basename(next_image)) # Debugging line return gr.update(value=next_image), gr.update(value=""), gr.update(value=next_caption or "") else: return gr.update(visible=False), gr.update(value="All images have been annotated!"), gr.update(value="") def initialize_interface(session_data): next_image = get_next_image(session_data) if next_image: next_caption = get_caption_for_image_id(os.path.basename(next_image)) # Retrieve caption for initial image print("Initial image_id:", os.path.basename(next_image)) # Print the initial image_id return gr.update(value=next_image), gr.update(value=next_caption or "") else: return gr.update(visible=False), gr.update(value="All images have been annotated!") # Build the Gradio interface with gr.Blocks() as demo: gr.Markdown("# Image Captioning Tool") gr.Markdown("Please provide your caption in Egyptian Arabic 'Masri'") session_data = gr.State({"current_image": None}) # Session-specific state with gr.Row(): image = gr.Image() caption = gr.Textbox(placeholder="Enter caption here...") existing_caption = gr.Textbox(label="Existing Caption", interactive=False) # Display existing caption submit = gr.Button("Submit", interactive=False) # Initially disabled # Enable/disable the submit button based on word count caption.change(fn=check_word_count, inputs=caption, outputs=submit) # Define actions for buttons submit.click(fn=save_annotation, inputs=[caption, session_data], outputs=[image, caption, existing_caption]) # Load initial image demo.load(fn=initialize_interface, inputs=session_data, outputs=[image, existing_caption]) demo.launch(share=True)