import torch
import gradio as gr
from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel
import zipfile
import os
import csv
from PIL import Image
device = 'cpu'
model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-40k-05"

# Load the pretrained model, feature extractor, and tokenizer
model = VisionEncoderDecoderModel.from_pretrained(model_name).to(device)
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
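# Note: generation runs on CPU here. If the Space had GPU hardware, setting
# `device = 'cuda' if torch.cuda.is_available() else 'cpu'` (an optional tweak,
# not part of the original configuration) would speed up inference considerably.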
def predict(image, max_length=64, num_beams=4):
    # Process the input image
    image = image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    # Generate the caption
    caption_ids = model.generate(pixel_values, max_length=max_length, num_beams=num_beams)[0]
    # Decode and clean the generated caption
    caption = tokenizer.decode(caption_ids, skip_special_tokens=True)
    return caption
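# Minimal usage sketch for the helper above (illustrative only; 'sample.jpg'
# is a placeholder path, not a file shipped with this Space):
#     img = Image.open('sample.jpg')
#     print(predict(img, max_length=32, num_beams=2))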
def process_images(image_files):
    captions = []
    for image_file in image_files:
        try:
            # Open and verify the image
            with Image.open(image_file) as img:
                caption = predict(img)
                captions.append((os.path.basename(image_file), caption))
        except Exception as e:
            print(f"Skipping file {image_file}: {e}")
    # Save the results to a CSV file
    csv_file_path = 'image_captions.csv'
    with open(csv_file_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Image', 'Caption'])
        writer.writerows(captions)
    return csv_file_path
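# The CSV written above holds one "<file name>,<caption>" row per image that
# could be opened and captioned; unreadable files are only logged and skipped.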
def process_zip_files(zip_file_paths):
    # Create a directory to extract images
    extract_dir = 'extracted_images'
    os.makedirs(extract_dir, exist_ok=True)
    captions = []
    for zip_file_path in zip_file_paths:
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(extract_dir)
    # Verify extracted files and process images
    for root, dirs, files in os.walk(extract_dir):
        for file in files:
            file_path = os.path.join(root, file)
            try:
                # Open and verify the image
                with Image.open(file_path) as img:
                    caption = predict(img)
                    captions.append((file, caption))
            except Exception as e:
                print(f"Skipping file {file}: {e}")
    # Save the results to a CSV file
    csv_file_path = 'zip_image_captions.csv'
    with open(csv_file_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Image Name', 'Caption'])
        writer.writerows(captions)
    return csv_file_path
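# Note: 'extracted_images' persists between runs, so files from a previous
# upload can reappear in later CSVs; non-image entries inside the archives are
# simply skipped by the exception handler above.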
def gr_process(zip_files, image_files):
    if not zip_files and not image_files:
        raise ValueError("At least one of zip_files or image_files must be provided.")
    elif zip_files:
        # gr.File with type="filepath" typically passes plain path strings;
        # fall back to .name for file-object style values from older Gradio versions.
        zip_file_paths = [f if isinstance(f, str) else f.name for f in zip_files]
        return process_zip_files(zip_file_paths)
    elif image_files:
        image_file_paths = [f if isinstance(f, str) else f.name for f in image_files]
        return process_images(image_file_paths)
def combine_csv_files(file1, file2, output_file='combined_captions.csv'):
    with open(output_file, mode='w', newline='') as outfile:
        writer = csv.writer(outfile)
        writer.writerow(['Image Name', 'Caption'])
        for file in [file1, file2]:
            # Skip inputs that were never generated (None) or no longer exist
            if file and os.path.exists(file):
                with open(file, mode='r') as infile:
                    reader = csv.reader(infile)
                    next(reader, None)  # Skip header row
                    for row in reader:
                        writer.writerow(row)
    return output_file
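# combine_csv_files is wired to the two download components below, so it simply
# concatenates whichever of the zip/image CSVs have already been generated.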
css = '''
h1#title {
    text-align: center;
}
h3#header {
    text-align: center;
}
img#overview {
    max-width: 800px;
    max-height: 600px;
}
img#style-image {
    max-width: 1000px;
    max-height: 600px;
}
.gr-image {
    max-width: 150px;  /* Set a small box for the image */
    max-height: 150px;
}
'''
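# Only the h1#title rule appears to be exercised below (via the markdown title);
# the remaining selectors look like leftovers for image elements that this
# layout does not create.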
demo = gr.Blocks(css=css)
with demo:
    gr.Markdown('''<h1 id="title">Image Caption 🖼️</h1>''')
    gr.Markdown('''Made by: No. Fa.''')
    zip_files = gr.State([])
    image_files = gr.State([])
    with gr.Row():
        with gr.Column(scale=1):
            new_zip_files = gr.File(label="Upload Zip Files", type="filepath", file_count="multiple")
            generate_zip_captions_btn = gr.Button("Generate Zip Captions")
            new_image_files = gr.File(label="Upload Images", type="filepath", file_count="multiple")
            generate_image_captions_btn = gr.Button("Generate Image Captions")
        with gr.Column(scale=3):
            output_zip_file = gr.File(label="Download Zip Captions")
            output_image_file = gr.File(label="Download Image Captions")
            combined_file = gr.File(label="Download Combined Captions")
            combine_files_btn = gr.Button("Combine CSV Files")

    # gr_process expects both argument slots, so each button fills in the one it
    # does not use; the image button reads the upload component (new_image_files)
    # rather than the unused State.
    generate_zip_captions_btn.click(fn=lambda files: gr_process(files, None), inputs=new_zip_files, outputs=output_zip_file)
    generate_image_captions_btn.click(fn=lambda files: gr_process(None, files), inputs=new_image_files, outputs=output_image_file)
    combine_files_btn.click(fn=combine_csv_files, inputs=[output_zip_file, output_image_file], outputs=combined_file)

demo.launch()
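# On Spaces the default launch() is enough; when running locally, an optional
# demo.launch(share=True) could be used instead to get a temporary public link.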