# Data_Collection/app.py
import gradio as gr
import os
import threading
import random
from datasets import load_dataset, Dataset, Features, Value, concatenate_datasets
from huggingface_hub import login
import json
import re
# reset = False
def check_word_count(caption):
    # Enable Submit only for captions of three or more words, or the literal
    # "skip" sentinel that save_annotation() below recognizes
    text = caption.strip()
    return gr.update(interactive=len(text.split()) >= 3 or text.lower() == "skip")
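# Illustrative gate behavior (hypothetical inputs): "a red car" (three words)
# enables Submit, "car" leaves it disabled, and the literal "skip" sentinel is
# let through so annotators can pass on an image.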
# Authenticate with Hugging Face
token = os.getenv("HUGGINGFACE_TOKEN")
if token:
login(token=token)
else:
print("HUGGINGFACE_TOKEN environment variable not set.")
dataset_name = "GeorgeIbrahim/EGYCOCO" # Replace with your dataset name
with open('nearest_neighbors_with_captions.json', 'r') as f:
results = json.load(f)
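# Sketch of the assumed structure of nearest_neighbors_with_captions.json,
# inferred from the lookups in get_caption_for_image_id() below rather than a
# canonical schema. Keys are numeric image ids with leading zeros stripped:
# {
#   "111367": {
#     "caption": "...",
#     "nearest_neighbors": [
#       {"image_id": "123", "caption": "..."},
#       ...
#     ]
#   },
#   ...
# }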
# Load or create the dataset
try:
dataset = load_dataset(dataset_name, split="train")
dataset = dataset.filter(lambda example: example["image_id"] != "COCO_val2014_000000111367.jpg")
print("Loaded existing dataset:", dataset)
print("Dataset features:", dataset.features) # Check if 'split' is part of features
# Check if the 'split' column exists; if not, add it
if 'split' not in dataset.column_names:
# Define the 'split' values based on `image_id`
split_values = [
"dev" if example["image_id"] in results else "train"
for example in dataset
]
# Add 'split' column to the dataset
dataset = dataset.add_column("split", split_values)
print("Added 'split' column to dataset.")
else:
print("'split' column already exists.")
# Create a dictionary to keep track of the highest annotation count for each image
annotation_counts = {}
for example in dataset:
image_id = example["image_id"]
count = example["annotation_count"]
if image_id not in annotation_counts or count > annotation_counts[image_id]:
annotation_counts[image_id] = count
print("Annotation counts:", annotation_counts)
except Exception as e:
print(f"Error loading dataset: {e}")
    # Create an empty dataset if it doesn't exist (note: any load failure,
    # including a transient one, lands here)
features = Features({
'image_id': Value(dtype='string'),
'caption': Value(dtype='string'),
'annotation_count': Value(dtype='int32'),
'split': Value(dtype='string')
})
dataset = Dataset.from_dict({'image_id': [], 'caption': [], 'annotation_count': [], 'split': []}, features=features)
annotation_counts = {}
dataset.push_to_hub(dataset_name) # Push the empty dataset to Hugging Face
# Initialize or reset data as needed based on the `reset` flag
# if reset:
# # Clear the annotation counts
# annotation_counts = {}
# shown_counts = {} # If you are tracking shown counts separately for images
# # Optionally, clear or reinitialize the dataset
# features = Features({
# 'image_id': Value(dtype='string'),
# 'caption': Value(dtype='string'),
# 'annotation_count': Value(dtype='int32'),
# 'split': Value(dtype='string')
# })
# dataset = Dataset.from_dict({
# 'image_id': [],
# 'caption': [],
# 'annotation_count': [],
# 'split': []
# }, features=features)
# # Push the reset dataset to Hugging Face or perform other necessary actions
# dataset.push_to_hub(dataset_name)
# print("Data has been reset.")
image_folder = "images"
image_files = [f for f in os.listdir(image_folder) if f.endswith(('.png', '.jpg', '.jpeg'))]
len_files = len(image_files)
lock = threading.Lock()
def get_caption_for_image_id(image_path):
"""
Retrieve the caption for a given image_id from the JSON data.
"""
# Extract the numeric part of the image ID
match = re.search(r'_(\d+)\.', image_path)
if match:
image_id = match.group(1).lstrip('0') # Remove leading zeros
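        # e.g. a COCO-style name such as "COCO_val2014_000000111367.jpg" yields
        # match.group(1) == "000000111367", stripped to "111367"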
print("Searching for image_id:", image_id) # Debugging line
# Check if image_id is a test image
if image_id in results:
print("Found caption in results:", results[image_id]["caption"]) # Debugging line
return results[image_id]["caption"]
# If image_id is not a test image, search in nearest neighbors
for test_image_data in results.values():
for neighbor in test_image_data["nearest_neighbors"]:
if neighbor["image_id"] == image_id:
print("Found caption in nearest neighbors:", neighbor["caption"]) # Debugging line
return neighbor["caption"]
    # Return None if no caption was found (or the filename didn't match the
    # pattern); print the path rather than image_id, which is undefined when
    # the regex fails
    print("Caption not found for image:", image_path)  # Debugging line
    return None
# Function to get a random image that hasn’t been fully annotated
def get_next_image(session_data):
with lock:
        # Build the list of images that still need annotations
available_images = []
# Iterate over each image file to apply the filtering logic
for img in image_files:
# Match and extract the image_id from the filename
match = re.search(r'_(\d+)\.', img)
if match:
image_id_2 = match.group(1).lstrip('0') # Remove leading zeros
                # Keep images that still need annotations: dev images (present
                # in `results`) are captioned twice, train images once
                count = annotation_counts.get(img, 0)
                needed = 2 if image_id_2 in results else 1
                if count < needed:
                    available_images.append(img)
# print("Available images:", available_images) # Debugging line
print(available_images)
print("Remaining images: ", len_files - len(available_images))
# random.shuffle(available_images)
# Check if the user already has an image
if session_data["current_image"] is None and available_images:
# Assign a new random image to the user
session_data["current_image"] = random.choice(available_images)
# print("Current image_id:", session_data["current_image"]) # Print the current image_id
return os.path.join(image_folder, session_data["current_image"]) if session_data["current_image"] else None
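# Worked example (hypothetical counts): with
# annotation_counts == {"COCO_val2014_000000000123.jpg": 1}, the file stays
# available if "123" is a dev id in `results` (1 < 2); a train image with one
# annotation is already done (1 >= 1) and is skipped.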
# Function to save the annotation to Hugging Face dataset and fetch the next image
def save_annotation(caption, session_data):
global dataset, annotation_counts # Declare global dataset and annotation_counts at the start of the function
if session_data["current_image"] is None:
return gr.update(visible=False), gr.update(value="All images have been annotated!"), gr.update(value="")
with lock:
image_id = session_data["current_image"]
match = re.search(r'_(\d+)\.', image_id)
image_2 = match.group(1).lstrip('0')
split = "dev" if image_2 in results else "train"
# Save caption or "skipped" based on user input
if caption.strip().lower() == "skip":
caption = "skipped"
# Get current annotation count
annotation_count = annotation_counts.get(image_id, 0)
# Add the new annotation as a new row to the dataset
new_data = Dataset.from_dict({
"image_id": [image_id],
"caption": [caption],
"annotation_count": [annotation_count + 1],
"split": [split]
}, features=Features({
'image_id': Value(dtype='string'),
'caption': Value(dtype='string'),
'annotation_count': Value(dtype='int32'),
'split': Value(dtype='string')
}))
# Update the annotation count in the dictionary
annotation_counts[image_id] = annotation_count + 1
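        # Example (hypothetical): the first caption for an image is appended as
        # a row with annotation_count == 1; a second pass appends another row
        # with annotation_count == 2 rather than overwriting the first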
# Concatenate with the existing dataset and push the updated dataset to Hugging Face
dataset = concatenate_datasets([dataset, new_data])
dataset = dataset.filter(lambda example: example['caption'].strip() != "")
dataset.push_to_hub(dataset_name)
print("Pushed updated dataset")
        # Release the current image unconditionally; get_next_image() re-applies
        # the per-split annotation limits when picking the next one
        # if (split == "train" and annotation_count > 1) or (split == "dev" and annotation_count > 2):
        session_data["current_image"] = None
# Fetch the next image
next_image = get_next_image(session_data)
if next_image:
next_caption = get_caption_for_image_id(os.path.basename(next_image)) # Retrieve the caption for the new image
print("Next image_id:", os.path.basename(next_image)) # Debugging line
return gr.update(value=next_image), gr.update(value=""), gr.update(value=next_caption or "")
else:
return gr.update(visible=False), gr.update(value="All images have been annotated!"), gr.update(value="")
def initialize_interface(session_data):
next_image = get_next_image(session_data)
if next_image:
next_caption = get_caption_for_image_id(os.path.basename(next_image)) # Retrieve caption for initial image
print("Initial image_id:", os.path.basename(next_image)) # Print the initial image_id
return gr.update(value=next_image), gr.update(value=next_caption or "")
else:
return gr.update(visible=False), gr.update(value="All images have been annotated!")
# Build the Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Image Captioning Tool")
gr.Markdown("Please provide your caption in Egyptian Arabic 'Masri'")
session_data = gr.State({"current_image": None}) # Session-specific state
with gr.Row():
image = gr.Image()
caption = gr.Textbox(placeholder="Enter caption here...")
existing_caption = gr.Textbox(label="Existing Caption", interactive=False) # Display existing caption
submit = gr.Button("Submit", interactive=False) # Initially disabled
# Enable/disable the submit button based on word count
caption.change(fn=check_word_count, inputs=caption, outputs=submit)
# Define actions for buttons
submit.click(fn=save_annotation, inputs=[caption, session_data], outputs=[image, caption, existing_caption])
# Load initial image
demo.load(fn=initialize_interface, inputs=session_data, outputs=[image, existing_caption])
demo.launch(share=True)