NourFakih's picture
Update app.py
803b4d8 verified
import csv
import os
import tempfile
import zipfile

import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel
# Inference runs on the CPU: the model and every input tensor are moved to this device.
device = "cpu"
model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-40k-05"

# All three artefacts are loaded from the same checkpoint: the tokenizer that
# decodes generated ids, the image processor that produces pixel tensors, and
# the encoder-decoder model that generates the caption ids.
tokenizer = AutoTokenizer.from_pretrained(model_name)
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
model = VisionEncoderDecoderModel.from_pretrained(model_name)
model = model.to(device)
def predict(image, max_length=64, num_beams=4):
    """Generate a caption for a single image.

    Args:
        image: Image object exposing ``convert``; it is converted to RGB
            before being handed to the image processor.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        num_beams: Forwarded to ``model.generate`` as ``num_beams``.

    Returns:
        The decoded caption string with special tokens skipped.
    """
    rgb_image = image.convert('RGB')
    encoded = feature_extractor(images=rgb_image, return_tensors="pt")
    pixel_values = encoded.pixel_values.to(device)

    # Only one image was supplied, so the first generated sequence is the caption.
    generated = model.generate(pixel_values, max_length=max_length, num_beams=num_beams)
    first_sequence = generated[0]

    return tokenizer.decode(first_sequence, skip_special_tokens=True)
def process_images(image_files):
    """Caption each image path and write the results to ``image_captions.csv``.

    Files that cannot be opened or captioned are reported on stdout and
    skipped, so one bad file never aborts the whole batch.

    Args:
        image_files: Iterable of image file paths.

    Returns:
        The path of the CSV that was written. Its first row is the header
        ``Image, Caption``; every following row is ``(basename, caption)``.
    """
    rows = [['Image', 'Caption']]
    for path in image_files:
        try:
            with Image.open(path) as picture:
                text = predict(picture)
                rows.append((os.path.basename(path), text))
        except Exception as err:
            # Deliberate best-effort: log the problem and move on to the next file.
            print(f"Skipping file {path}: {err}")

    output_path = 'image_captions.csv'
    with open(output_path, mode='w', newline='') as handle:
        csv.writer(handle).writerows(rows)
    return output_path
def process_zip_files(zip_file_paths):
    """Caption every image inside the given zip archives and write a CSV.

    All archives are unpacked into one temporary directory that exists only
    for the duration of this call. Extracting into a fixed, persistent
    directory would make every later call re-caption the files left behind by
    earlier calls and would let uploads pile up on disk.

    Files that cannot be opened or captioned are reported on stdout and
    skipped.

    Args:
        zip_file_paths: Iterable of paths to zip archives.

    Returns:
        The path of the CSV that was written (``'zip_image_captions.csv'``).
        Its first row is the header ``Image Name, Caption``; every following
        row is ``(file name, caption)``.

    Raises:
        zipfile.BadZipFile, OSError: Propagated when an archive cannot be
            opened or extracted.
    """
    captions = []

    with tempfile.TemporaryDirectory() as extract_dir:
        # NOTE(security): the archives are user uploads. ``extractall`` is
        # documented to sanitise member paths, but nothing here limits the
        # decompressed size -- consider inspecting ``infolist()`` first.
        for zip_file_path in zip_file_paths:
            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)

        # Walk the whole extracted tree so images in sub-folders are found too.
        for root, _dirs, files in os.walk(extract_dir):
            for file_name in files:
                file_path = os.path.join(root, file_name)
                try:
                    with Image.open(file_path) as img:
                        caption = predict(img)
                        captions.append((file_name, caption))
                except Exception as e:
                    # Deliberate best-effort: archives often contain
                    # non-image files; report and continue.
                    print(f"Skipping file {file_name}: {e}")

    # The CSV is written after the temporary directory has been removed; only
    # the collected (name, caption) pairs are needed at this point.
    csv_file_path = 'zip_image_captions.csv'
    with open(csv_file_path, mode='w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(['Image Name', 'Caption'])
        writer.writerows(captions)
    return csv_file_path
def gr_process(zip_files, image_files=None):
    """Gradio callback that captions uploaded zip archives or uploaded images.

    ``image_files`` is optional so the callback also works when an event
    supplies only one input. Uploads may be plain path strings / path-like
    objects or file wrappers exposing the path as ``.name``.

    When both arguments are non-empty only the zip archives are processed.

    Args:
        zip_files: Sequence of uploaded zip archives, or an empty value.
        image_files: Sequence of uploaded images, or an empty value.

    Returns:
        The CSV path returned by ``process_zip_files`` or ``process_images``.

    Raises:
        ValueError: If neither argument contains any upload.
    """
    def _to_path(upload):
        # Paths are used as-is; anything else is expected to carry its
        # filesystem path in ``.name`` (as temporary-file wrappers do).
        if isinstance(upload, (str, os.PathLike)):
            return os.fspath(upload)
        return upload.name

    if not zip_files and not image_files:
        raise ValueError("At least one of zip_files or image_files must be provided.")
    if zip_files:
        return process_zip_files([_to_path(upload) for upload in zip_files])
    return process_images([_to_path(upload) for upload in image_files])
def combine_csv_files(file1, file2, output_file='combined_captions.csv'):
    """Concatenate the data rows of two caption CSV files into one CSV.

    The output always starts with the header ``Image Name, Caption``. The
    first row of each source is treated as its header and dropped. A source
    that is ``None``, empty, or does not exist on disk is skipped, so the
    function still works when only one caption file is available.

    Args:
        file1: Path of the first CSV, or ``None``.
        file2: Path of the second CSV, or ``None``.
        output_file: Path of the combined CSV to write.

    Returns:
        ``output_file``.
    """
    with open(output_file, mode='w', newline='') as outfile:
        writer = csv.writer(outfile)
        writer.writerow(['Image Name', 'Caption'])
        for source in (file1, file2):
            # ``os.path.exists(None)`` raises TypeError, so test for a missing
            # value before touching the filesystem.
            if not source or not os.path.exists(source):
                continue
            # newline='' lets the csv module handle line endings and quoted
            # newlines itself, as the csv documentation requires.
            with open(source, mode='r', newline='') as infile:
                reader = csv.reader(infile)
                next(reader, None)  # Skip the header; tolerate an empty file.
                writer.writerows(reader)
    return output_file
# Custom stylesheet handed to gr.Blocks(css=...). It centres the title and
# header elements and caps the rendered size of several image selectors.
# NOTE(review): of these selectors only `h1#title` is referenced by the markup
# visible in this file; the others may be unused -- confirm before pruning.
css = '''
h1#title {
text-align: center;
}
h3#header {
text-align: center;
}
img#overview {
max-width: 800px;
max-height: 600px;
}
img#style-image {
max-width: 1000px;
max-height: 600px;
}
.gr-image {
max-width: 150px; /* Set a small box for the image */
max-height: 150px;
}
'''
# Build the Gradio interface: uploads and action buttons on the left,
# downloadable CSV results on the right.
demo = gr.Blocks(css=css)
with demo:
    gr.Markdown('''<h1 id="title">Image Caption 🖼️</h1>''')
    gr.Markdown('''Made by : No. Fa.''')

    # Empty placeholders passed for whichever gr_process argument a button
    # does not use.
    zip_files = gr.State([])
    image_files = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            new_zip_files = gr.File(label="Upload Zip Files", type="filepath", file_count="multiple")
            generate_zip_captions_btn = gr.Button("Generate Zip Captions")
            new_image_files = gr.File(label="Upload Images", type="filepath", file_count="multiple")
            generate_image_captions_btn = gr.Button("Generate Image Captions")
        with gr.Column(scale=3):
            output_zip_file = gr.File(label="Download Zip Captions")
            output_image_file = gr.File(label="Download Image Captions")
            combined_file = gr.File(label="Download Combined Captions")
            combine_files_btn = gr.Button("Combine CSV Files")

    # gr_process takes (zip_files, image_files), so each button supplies both
    # positions: its own upload component plus an empty placeholder. The image
    # button must read the upload component `new_image_files`, not the empty
    # `image_files` state, otherwise no image ever reaches the callback.
    generate_zip_captions_btn.click(
        fn=gr_process,
        inputs=[new_zip_files, image_files],
        outputs=output_zip_file,
    )
    generate_image_captions_btn.click(
        fn=gr_process,
        inputs=[zip_files, new_image_files],
        outputs=output_image_file,
    )
    combine_files_btn.click(
        fn=combine_csv_files,
        inputs=[output_zip_file, output_image_file],
        outputs=combined_file,
    )

demo.launch()