# Hugging Face Spaces status banner captured with this file: "Runtime error".
import csv
import os
import tempfile
import zipfile

import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel
# Checkpoint to load and the device that the model and its inputs are moved to.
model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-40k-05"
device = 'cpu'

# Tokenizer: decodes generated token ids back into caption text.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Image processor: converts an image into the `pixel_values` tensor fed to the model.
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
# Vision encoder-decoder captioning model, placed on `device`.
model = VisionEncoderDecoderModel.from_pretrained(model_name).to(device)
def predict(image, max_length=64, num_beams=4):
    """Generate a caption for one image.

    Args:
        image: Image object exposing ``convert('RGB')`` (callers in this file
            pass images opened with ``PIL.Image.open``).
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        num_beams: Forwarded to ``model.generate`` as ``num_beams``.

    Returns:
        The decoded caption string with special tokens removed.
    """
    # Normalise the colour mode before handing the image to the processor.
    rgb_image = image.convert('RGB')
    encoded = feature_extractor(images=rgb_image, return_tensors="pt")
    pixel_values = encoded.pixel_values.to(device)

    generated = model.generate(
        pixel_values,
        max_length=max_length,
        num_beams=num_beams,
    )
    # A single image was passed in, so only the first generated sequence is used.
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
def process_images(image_files):
    """Caption every image in ``image_files`` and write the results to a CSV.

    Args:
        image_files: Iterable of image file paths. ``None`` is treated as an
            empty list so a call with no uploads still produces a valid
            (header-only) CSV instead of failing on iteration.

    Returns:
        The path of the written CSV file (``'image_captions.csv'``), whose
        columns are ``Image`` (file base name) and ``Caption``.
    """
    captions = []
    for image_file in image_files or []:
        try:
            # Opening doubles as validation: anything that is not a readable
            # image raises here and is skipped.
            with Image.open(image_file) as img:
                caption = predict(img)
            captions.append((os.path.basename(image_file), caption))
        except Exception as e:
            # Best effort: one bad file must not abort the whole batch.
            print(f"Skipping file {image_file}: {e}")

    # Save the results to a CSV file
    csv_file_path = 'image_captions.csv'
    with open(csv_file_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Image', 'Caption'])
        writer.writerows(captions)
    return csv_file_path
def process_zip_files(zip_file_paths):
    """Caption every image found inside the given zip archives.

    Each archive is extracted into its own temporary directory that is removed
    once the archive has been processed. This keeps files from other archives
    (or from earlier calls) from being captioned again and prevents same-named
    files in different archives from overwriting each other.

    Args:
        zip_file_paths: Iterable of paths to zip archives. ``None`` is treated
            as an empty list.

    Returns:
        The path of the written CSV file (``'zip_image_captions.csv'``), whose
        columns are ``Image Name`` (file name inside the archive) and
        ``Caption``.

    Raises:
        zipfile.BadZipFile: If one of the paths is not a valid zip archive.
    """
    captions = []
    for zip_file_path in zip_file_paths or []:
        with tempfile.TemporaryDirectory() as extract_dir:
            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)

            for root, dirs, files in os.walk(extract_dir):
                # Sort in place so the traversal (and the CSV row order) is
                # reproducible across runs and platforms.
                dirs.sort()
                for file in sorted(files):
                    file_path = os.path.join(root, file)
                    try:
                        # Opening doubles as validation: non-image members
                        # raise here and are skipped.
                        with Image.open(file_path) as img:
                            caption = predict(img)
                        captions.append((file, caption))
                    except Exception as e:
                        # Best effort: one bad member must not abort the batch.
                        print(f"Skipping file {file}: {e}")

    # Save the results to a CSV file
    csv_file_path = 'zip_image_captions.csv'
    with open(csv_file_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Image Name', 'Caption'])
        writer.writerows(captions)
    return csv_file_path
def gr_process(zip_files, image_files=None):
    """Dispatch uploaded files to the zip or image captioning pipeline.

    Zip archives take priority: when ``zip_files`` is non-empty it is processed
    and ``image_files`` is ignored.

    Args:
        zip_files: Uploaded zip archives, or ``None``/empty.
        image_files: Uploaded image files, or ``None``/empty. Optional so the
            function can be bound to an event that supplies a single input.

        Entries of either list may be path strings, ``os.PathLike`` objects,
        or file-like objects exposing the path as ``.name``.

    Returns:
        The path of the CSV file written by the selected pipeline.

    Raises:
        ValueError: If neither ``zip_files`` nor ``image_files`` has entries.
    """
    def _to_path(entry):
        # Accept both plain paths and upload wrappers that carry `.name`.
        if isinstance(entry, (str, os.PathLike)):
            return os.fspath(entry)
        return entry.name

    if not zip_files and not image_files:
        raise ValueError("At least one of zip_files or image_files must be provided.")
    if zip_files:
        return process_zip_files([_to_path(entry) for entry in zip_files])
    return process_images([_to_path(entry) for entry in image_files])
def combine_csv_files(file1, file2, output_file='combined_captions.csv'):
    """Concatenate the data rows of two caption CSV files into one CSV.

    The output always starts with the header ``Image Name, Caption``. The
    first row of each input is treated as its header and dropped.

    Args:
        file1: Path of the first CSV, or ``None``.
        file2: Path of the second CSV, or ``None``.
        output_file: Path of the combined CSV to write.

    Inputs that are ``None``/empty, do not exist, or contain no rows are
    skipped rather than raising, so the function can be called before both
    source files have been generated.

    Returns:
        ``output_file``.
    """
    with open(output_file, mode='w', newline='') as outfile:
        writer = csv.writer(outfile)
        writer.writerow(['Image Name', 'Caption'])
        for source in (file1, file2):
            # `os.path.exists(None)` raises TypeError, so test for a value first.
            if not source or not os.path.exists(source):
                continue
            # newline='' lets the csv module handle quoted fields that
            # contain line breaks.
            with open(source, mode='r', newline='') as infile:
                reader = csv.reader(infile)
                next(reader, None)  # Skip header row; tolerate an empty file
                writer.writerows(reader)
    return output_file
# Custom CSS injected into the Blocks app: centred headings and size caps for
# images.
css = '''
h1#title {
  text-align: center;
}
h3#header {
  text-align: center;
}
img#overview {
  max-width: 800px;
  max-height: 600px;
}
img#style-image {
  max-width: 1000px;
  max-height: 600px;
}
.gr-image {
  max-width: 150px; /* Set a small box for the image */
  max-height: 150px;
}
'''

demo = gr.Blocks(css=css)
with demo:
    gr.Markdown('''<h1 id="title">Image Caption 🖼️</h1>''')
    gr.Markdown('''Made by : No. Fa.''')
    with gr.Row():
        # Left column: uploads and the buttons that trigger captioning.
        with gr.Column(scale=1):
            new_zip_files = gr.File(label="Upload Zip Files", type="filepath", file_count="multiple")
            generate_zip_captions_btn = gr.Button("Generate Zip Captions")
            new_image_files = gr.File(label="Upload Images", type="filepath", file_count="multiple")
            generate_image_captions_btn = gr.Button("Generate Image Captions")
        # Right column: downloadable CSV results.
        with gr.Column(scale=3):
            output_zip_file = gr.File(label="Download Zip Captions")
            output_image_file = gr.File(label="Download Image Captions")
            combined_file = gr.File(label="Download Combined Captions")
            combine_files_btn = gr.Button("Combine CSV Files")

    # Each button reads from its own upload widget. gr_process takes
    # (zip_files, image_files), so the lambdas supply both arguments and leave
    # the unrelated one empty; that way each button only runs its own pipeline.
    generate_zip_captions_btn.click(
        fn=lambda zips: gr_process(zips, None),
        inputs=new_zip_files,
        outputs=output_zip_file,
    )
    generate_image_captions_btn.click(
        fn=lambda images: gr_process(None, images),
        inputs=new_image_files,
        outputs=output_image_file,
    )
    combine_files_btn.click(
        fn=combine_csv_files,
        inputs=[output_zip_file, output_image_file],
        outputs=combined_file,
    )

demo.launch()