import torch
import gradio as gr
from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel
import zipfile
import os
import csv
from PIL import Image

device = 'cpu'
model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-40k-05"

# Load the pretrained model, feature extractor, and tokenizer
model = VisionEncoderDecoderModel.from_pretrained(model_name).to(device)
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


def predict(image, max_length=64, num_beams=4):
    # Process the input image
    image = image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)

    # Generate the caption
    caption_ids = model.generate(pixel_values, max_length=max_length, num_beams=num_beams)[0]

    # Decode and clean the generated caption
    caption = tokenizer.decode(caption_ids, skip_special_tokens=True)
    return caption


def process_images(image_files):
    captions = []
    for image_file in image_files:
        try:
            # Open and verify the image
            with Image.open(image_file) as img:
                caption = predict(img)
                captions.append((os.path.basename(image_file), caption))
        except Exception as e:
            print(f"Skipping file {image_file}: {e}")

    # Save the results to a CSV file
    csv_file_path = 'image_captions.csv'
    with open(csv_file_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Image', 'Caption'])
        writer.writerows(captions)
    return csv_file_path


def process_zip_files(zip_file_paths):
    # Create a directory to extract images
    extract_dir = 'extracted_images'
    os.makedirs(extract_dir, exist_ok=True)

    captions = []
    for zip_file_path in zip_file_paths:
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(extract_dir)

    # Verify extracted files and process images
    for root, dirs, files in os.walk(extract_dir):
        for file in files:
            file_path = os.path.join(root, file)
            try:
                # Open and verify the image
                with Image.open(file_path) as img:
                    caption = predict(img)
                    captions.append((file, caption))
            except Exception as e:
                print(f"Skipping file {file}: {e}")

    # Save the results to a CSV file
    csv_file_path = 'zip_image_captions.csv'
    with open(csv_file_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Image Name', 'Caption'])
        writer.writerows(captions)
    return csv_file_path


def gr_process(zip_files, image_files):
    if not zip_files and not image_files:
        raise ValueError("At least one of zip_files or image_files must be provided.")
    elif zip_files:
        zip_file_paths = [zip_file.name for zip_file in zip_files]
        return process_zip_files(zip_file_paths)
    elif image_files:
        image_file_paths = [image_file.name for image_file in image_files]
        return process_images(image_file_paths)


def combine_csv_files(file1, file2, output_file='combined_captions.csv'):
    with open(output_file, mode='w', newline='') as outfile:
        writer = csv.writer(outfile)
        writer.writerow(['Image Name', 'Caption'])
        for file in [file1, file2]:
            if os.path.exists(file):
                with open(file, mode='r') as infile:
                    reader = csv.reader(infile)
                    next(reader)  # Skip header row
                    for row in reader:
                        writer.writerow(row)
    return output_file


css = '''
h1#title { text-align: center; }
h3#header { text-align: center; }
img#overview { max-width: 800px; max-height: 600px; }
img#style-image { max-width: 1000px; max-height: 600px; }
.gr-image {
    max-width: 150px; /* Set a small box for the image */
    max-height: 150px;
}
'''

demo = gr.Blocks(css=css)

with demo:
    gr.Markdown('''