import gradio as gr
import os
import threading
import random
from datasets import load_dataset, Dataset, Features, Value, concatenate_datasets
from huggingface_hub import login
import json
import re
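
# Gradio annotation app: serves images from a local `images/` folder, collects
# Egyptian Arabic ("Masri") captions from annotators, and appends each caption
# as a new row to a Hugging Face dataset. Images whose ids appear in
# nearest_neighbors_with_captions.json form the dev split and are captioned
# twice; all other images form the train split and are captioned once.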

# reset = False

def check_word_count(caption):
    """Enable the submit button only when the caption has 3 or more words."""
    return gr.update(interactive=len(caption.split()) >= 3)

# Authenticate with Hugging Face
token = os.getenv("HUGGINGFACE_TOKEN")
if token:
    login(token=token)
else:
    print("HUGGINGFACE_TOKEN environment variable not set.")

dataset_name = "GeorgeIbrahim/EGYCOCO"  # Replace with your dataset name

# Load the precomputed nearest-neighbor captions
with open('nearest_neighbors_with_captions.json', 'r') as f:
    results = json.load(f)
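
# Assumed structure of nearest_neighbors_with_captions.json, inferred from the
# lookups in get_caption_for_image_id below (keys and captions here are
# illustrative, not taken from the actual file):
# {
#     "111367": {
#         "caption": "reference caption for this dev image",
#         "nearest_neighbors": [
#             {"image_id": "205289", "caption": "caption of a neighboring train image"},
#             ...
#         ]
#     },
#     ...
# }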

# Load or create the dataset
try:
    dataset = load_dataset(dataset_name, split="train")
    dataset = dataset.filter(lambda example: example["image_id"] != "COCO_val2014_000000111367.jpg")
    print("Loaded existing dataset:", dataset)
    print("Dataset features:", dataset.features)  # Check if 'split' is part of features

    # Check if the 'split' column exists; if not, add it
    if 'split' not in dataset.column_names:
        # Define the 'split' values based on `image_id`. `results` is keyed by
        # the zero-stripped numeric id, so extract it from the filename first
        # (this mirrors the extraction in save_annotation below).
        split_values = []
        for example in dataset:
            m = re.search(r'_(\d+)\.', example["image_id"])
            numeric_id = m.group(1).lstrip('0') if m else ""
            split_values.append("dev" if numeric_id in results else "train")
        # Add 'split' column to the dataset
        dataset = dataset.add_column("split", split_values)
        print("Added 'split' column to dataset.")
    else:
        print("'split' column already exists.")

    # Track the highest annotation count seen for each image
    annotation_counts = {}
    for example in dataset:
        image_id = example["image_id"]
        count = example["annotation_count"]
        if image_id not in annotation_counts or count > annotation_counts[image_id]:
            annotation_counts[image_id] = count
    print("Annotation counts:", annotation_counts)
except Exception as e:
    print(f"Error loading dataset: {e}")
    # Create an empty dataset if it doesn't exist
    features = Features({
        'image_id': Value(dtype='string'),
        'caption': Value(dtype='string'),
        'annotation_count': Value(dtype='int32'),
        'split': Value(dtype='string')
    })
    dataset = Dataset.from_dict({'image_id': [], 'caption': [], 'annotation_count': [], 'split': []}, features=features)
    annotation_counts = {}
    dataset.push_to_hub(dataset_name)  # Push the empty dataset to Hugging Face

# Initialize or reset data as needed based on the `reset` flag
# if reset:
#     # Clear the annotation counts
#     annotation_counts = {}
#     shown_counts = {}  # If you are tracking shown counts separately for images
#     # Optionally, clear or reinitialize the dataset
#     features = Features({
#         'image_id': Value(dtype='string'),
#         'caption': Value(dtype='string'),
#         'annotation_count': Value(dtype='int32'),
#         'split': Value(dtype='string')
#     })
#     dataset = Dataset.from_dict({
#         'image_id': [],
#         'caption': [],
#         'annotation_count': [],
#         'split': []
#     }, features=features)
#     # Push the reset dataset to Hugging Face or perform other necessary actions
#     dataset.push_to_hub(dataset_name)
#     print("Data has been reset.")

image_folder = "images"
image_files = [f for f in os.listdir(image_folder) if f.endswith(('.png', '.jpg', '.jpeg'))]
len_files = len(image_files)
lock = threading.Lock()
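
# The lock serializes access to the shared annotation_counts dict and the
# global dataset, since Gradio may serve concurrent annotators on separate
# threads.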

def get_caption_for_image_id(image_path):
    """
    Retrieve the caption for a given image_id from the JSON data.
    """
    # Extract the numeric part of the image ID
    match = re.search(r'_(\d+)\.', image_path)
    if not match:
        print("Could not extract an image_id from:", image_path)  # Debugging line
        return None
    image_id = match.group(1).lstrip('0')  # Remove leading zeros
    print("Searching for image_id:", image_id)  # Debugging line
    # Check if image_id is a test image
    if image_id in results:
        print("Found caption in results:", results[image_id]["caption"])  # Debugging line
        return results[image_id]["caption"]
    # If image_id is not a test image, search in nearest neighbors
    for test_image_data in results.values():
        for neighbor in test_image_data["nearest_neighbors"]:
            if neighbor["image_id"] == image_id:
                print("Found caption in nearest neighbors:", neighbor["caption"])  # Debugging line
                return neighbor["caption"]
    # Return None if the image_id is not found
    print("Caption not found for image_id:", image_id)  # Debugging line
    return None
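
# Example (filename format as used in the dataset filter above):
#     get_caption_for_image_id("COCO_val2014_000000111367.jpg")
# returns the stored caption string, or None if the numeric id is unknown.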

# Function to get a random image that hasn't been fully annotated
def get_next_image(session_data):
    with lock:
        # Collect the images that still need annotations
        available_images = []
        for img in image_files:
            # Match and extract the image_id from the filename
            match = re.search(r'_(\d+)\.', img)
            if not match:
                continue
            image_id_2 = match.group(1).lstrip('0')  # Remove leading zeros
            count = annotation_counts.get(img, 0)
            # Dev images need two annotations; train images need one
            if (image_id_2 in results and count < 2) or (image_id_2 not in results and count == 0):
                available_images.append(img)
        # print("Available images:", available_images)  # Debugging line
        print(f"Remaining images: {len(available_images)} of {len_files}")
        # random.shuffle(available_images)
        # Assign a new random image only if the user doesn't already have one
        if session_data["current_image"] is None and available_images:
            session_data["current_image"] = random.choice(available_images)
        # print("Current image_id:", session_data["current_image"])  # Debugging line
        return os.path.join(image_folder, session_data["current_image"]) if session_data["current_image"] else None
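
# Note on the availability filter above: images whose numeric id appears in
# `results` (the dev split) are served until they have two captions; all other
# images (the train split) are served until they have one. Counts are keyed by
# filename in annotation_counts.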

# Function to save the annotation to the Hugging Face dataset and fetch the next image
def save_annotation(caption, session_data):
    global dataset, annotation_counts  # Declare globals at the start of the function
    if session_data["current_image"] is None:
        return gr.update(visible=False), gr.update(value="All images have been annotated!"), gr.update(value="")
    with lock:
        image_id = session_data["current_image"]
        match = re.search(r'_(\d+)\.', image_id)
        image_2 = match.group(1).lstrip('0') if match else ""
        split = "dev" if image_2 in results else "train"
        # Save the caption, or "skipped" if the user typed "skip"
        if caption.strip().lower() == "skip":
            caption = "skipped"
        # Get the current annotation count
        annotation_count = annotation_counts.get(image_id, 0)
        # Add the new annotation as a new row to the dataset
        new_data = Dataset.from_dict({
            "image_id": [image_id],
            "caption": [caption],
            "annotation_count": [annotation_count + 1],
            "split": [split]
        }, features=Features({
            'image_id': Value(dtype='string'),
            'caption': Value(dtype='string'),
            'annotation_count': Value(dtype='int32'),
            'split': Value(dtype='string')
        }))
        # Update the annotation count in the dictionary
        annotation_counts[image_id] = annotation_count + 1
        # Concatenate with the existing dataset, drop empty captions, and push to Hugging Face
        dataset = concatenate_datasets([dataset, new_data])
        dataset = dataset.filter(lambda example: example['caption'].strip() != "")
        dataset.push_to_hub(dataset_name)
        print("Pushed updated dataset")
        # # Clear the user's current image only once it has enough annotations
        # if (split == "train" and annotation_count > 1) or (split == "dev" and annotation_count > 2):
        session_data["current_image"] = None
    # Fetch the next image; get_next_image acquires the (non-reentrant) lock
    # itself, so it must be called after the with-block releases it
    next_image = get_next_image(session_data)
    if next_image:
        next_caption = get_caption_for_image_id(os.path.basename(next_image))  # Retrieve the caption for the new image
        print("Next image_id:", os.path.basename(next_image))  # Debugging line
        return gr.update(value=next_image), gr.update(value=""), gr.update(value=next_caption or "")
    else:
        return gr.update(visible=False), gr.update(value="All images have been annotated!"), gr.update(value="")
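
# Design note: every submission re-uploads the full concatenated dataset via
# push_to_hub. That keeps the Hub copy authoritative at all times, but it gets
# slower as the dataset grows; batching pushes would be a possible optimization.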

def initialize_interface(session_data):
    next_image = get_next_image(session_data)
    if next_image:
        next_caption = get_caption_for_image_id(os.path.basename(next_image))  # Retrieve the caption for the initial image
        print("Initial image_id:", os.path.basename(next_image))  # Debugging line
        return gr.update(value=next_image), gr.update(value=next_caption or "")
    else:
        return gr.update(visible=False), gr.update(value="All images have been annotated!")

# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Captioning Tool")
    gr.Markdown("Please provide your caption in Egyptian Arabic 'Masri'")
    session_data = gr.State({"current_image": None})  # Session-specific state
    with gr.Row():
        image = gr.Image()
        caption = gr.Textbox(placeholder="Enter caption here...")
        existing_caption = gr.Textbox(label="Existing Caption", interactive=False)  # Display the existing caption
    submit = gr.Button("Submit", interactive=False)  # Initially disabled
    # Enable/disable the submit button based on word count
    caption.change(fn=check_word_count, inputs=caption, outputs=submit)
    # Define actions for buttons
    submit.click(fn=save_annotation, inputs=[caption, session_data], outputs=[image, caption, existing_caption])
    # Load the initial image
    demo.load(fn=initialize_interface, inputs=session_data, outputs=[image, existing_caption])

demo.launch(share=True)
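
# Note: share=True matters for local runs (it requests a public gradio.live
# link); when hosted on Hugging Face Spaces the app is already publicly
# reachable, so the flag is effectively redundant there.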