File size: 5,379 Bytes
44f5cce
 
 
 
 
 
 
 
 
803b4d8
44f5cce
3f84488
 
 
44f5cce
 
 
 
 
 
 
 
 
 
 
 
 
c29228f
 
 
 
 
 
 
 
 
 
024cdac
924f604
 
 
 
 
 
 
 
c29228f
 
44f5cce
 
 
 
 
c29228f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
024cdac
924f604
 
ca2369e
 
f3441d1
ca2369e
924f604
44f5cce
 
924f604
c0548b1
 
 
924f604
 
 
 
 
c0548b1
84468c4
 
 
 
 
 
 
 
 
 
 
 
5faae13
84468c4
cf6a90a
44f5cce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
024cdac
44f5cce
5faae13
 
 
44f5cce
 
5faae13
290ba63
5faae13
290ba63
84468c4
 
 
 
 
5faae13
290ba63
921ba61
7345e57
84468c4
44f5cce
eac80a1
290ba63
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import csv
import os
import tempfile
import zipfile

import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel

# Device string handed to `.to(...)` for both the model (below) and the
# pixel tensors built in predict().
device = 'cpu'
# Checkpoint identifier passed to every `from_pretrained` call below, so the
# model, image processor and tokenizer all come from the same source.
model_name="NourFakih/Vit-GPT2-COCO2017Flickr-40k-05"
# Load the pretrained model, feature extractor, and tokenizer once at import
# time; predict() reads these three module-level objects on every call.
model = VisionEncoderDecoderModel.from_pretrained(model_name).to(device)
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

def predict(image, max_length=64, num_beams=4):
    """Return a text caption for a single image.

    Args:
        image: Object exposing ``convert('RGB')``; callers in this file pass
            images opened with ``PIL.Image.open``.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        num_beams: Forwarded to ``model.generate`` as ``num_beams``.

    Returns:
        The decoded first generated sequence, with special tokens skipped.
    """
    # Normalise the colour mode before handing the image to the processor.
    rgb_image = image.convert('RGB')
    encoded = feature_extractor(images=rgb_image, return_tensors="pt")
    pixel_values = encoded.pixel_values.to(device)

    # Only the first generated sequence is used.
    generated = model.generate(
        pixel_values,
        max_length=max_length,
        num_beams=num_beams,
    )
    first_sequence = generated[0]

    return tokenizer.decode(first_sequence, skip_special_tokens=True)

def process_images(image_files):
    """Caption each image path and write the results to ``image_captions.csv``.

    Files that cannot be opened or captioned are reported on stdout and
    skipped, so one bad upload does not abort the whole batch.

    Args:
        image_files: Iterable of image file paths.

    Returns:
        The path of the CSV file that was written (``'image_captions.csv'``).
    """
    csv_file_path = 'image_captions.csv'
    # The header goes first so the whole table is written in one call.
    rows = [['Image', 'Caption']]

    for image_file in image_files:
        try:
            with Image.open(image_file) as img:
                text = predict(img)
                rows.append((os.path.basename(image_file), text))
        except Exception as e:
            print(f"Skipping file {image_file}: {e}")

    with open(csv_file_path, mode='w', newline='') as out:
        csv.writer(out).writerows(rows)

    return csv_file_path

def process_zip_files(zip_file_paths):
    """Caption every image found inside the given zip archives.

    Each archive is unpacked into its own temporary directory, which is
    removed as soon as its files have been captioned. Sharing a single
    persistent extraction directory made every later archive (and every later
    run) walk over files left behind by earlier ones, producing duplicate and
    stale rows.

    Files that cannot be opened or captioned are reported on stdout and
    skipped. A path that is not a valid zip archive still raises, as before.

    Args:
        zip_file_paths: Iterable of paths to zip archives.

    Returns:
        The path of the CSV file that was written
        (``'zip_image_captions.csv'``).
    """
    captions = []
    for zip_file_path in zip_file_paths:
        # A fresh scratch directory per archive keeps results isolated and
        # cleans up the extracted files automatically.
        with tempfile.TemporaryDirectory(prefix='extracted_images_') as extract_dir:
            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)

            # Archives may contain nested folders, so walk the whole tree.
            for root, _dirs, files in os.walk(extract_dir):
                for file in files:
                    file_path = os.path.join(root, file)
                    try:
                        with Image.open(file_path) as img:
                            caption = predict(img)
                    except Exception as e:
                        # Best effort: non-image members are expected.
                        print(f"Skipping file {file}: {e}")
                    else:
                        captions.append((file, caption))

    # Save the results to a CSV file
    csv_file_path = 'zip_image_captions.csv'
    with open(csv_file_path, mode='w', newline='') as csv_out:
        writer = csv.writer(csv_out)
        writer.writerow(['Image Name', 'Caption'])
        writer.writerows(captions)

    return csv_file_path

def gr_process(zip_files=None, image_files=None):
    """Dispatch uploaded files to the zip or the image captioning pipeline.

    When both arguments are non-empty, only the zip archives are processed.

    Args:
        zip_files: Uploaded zip archives, or ``None``/empty when absent.
        image_files: Uploaded images, or ``None``/empty when absent.
            Entries of either list may be plain path strings or objects
            exposing the path as ``.name``.

    Returns:
        The path of the CSV file produced by the selected pipeline.

    Raises:
        ValueError: If neither argument contains any file.
    """
    def _as_paths(files):
        # The upload components are declared with type="filepath", which is
        # expected to deliver plain strings; file-like wrappers with a .name
        # attribute are still accepted so either form works.
        return [
            item if isinstance(item, (str, os.PathLike)) else item.name
            for item in files
        ]

    if not zip_files and not image_files:
        raise ValueError("At least one of zip_files or image_files must be provided.")
    if zip_files:
        return process_zip_files(_as_paths(zip_files))
    return process_images(_as_paths(image_files))

def combine_csv_files(file1, file2, output_file='combined_captions.csv'):
    """Concatenate the data rows of two caption CSV files into one file.

    The first row of each input is treated as a header and dropped; a single
    ``['Image Name', 'Caption']`` header is written to the output instead.
    Inputs that are ``None``, empty, or point to a missing file are skipped,
    so the function works when only one of the two CSVs has been generated.

    Args:
        file1: Path of the first CSV (or an object exposing it as ``.name``),
            or ``None``.
        file2: Path of the second CSV, same conventions as ``file1``.
        output_file: Path of the combined CSV to write.

    Returns:
        ``output_file``.
    """
    with open(output_file, mode='w', newline='') as outfile:
        writer = csv.writer(outfile)
        writer.writerow(['Image Name', 'Caption'])

        for source in (file1, file2):
            if not source:
                # Nothing was provided for this slot.
                continue
            path = source if isinstance(source, (str, os.PathLike)) else source.name
            if not os.path.exists(path):
                continue
            # newline='' lets the csv module handle embedded line breaks.
            with open(path, mode='r', newline='') as infile:
                reader = csv.reader(infile)
                next(reader, None)  # Skip header row; tolerate empty files
                writer.writerows(reader)

    return output_file

# Custom stylesheet passed to gr.Blocks(css=...). Of the element ids styled
# here, only `title` is referenced by markup visible in this file (the <h1>
# in the first Markdown component).
css = '''
h1#title {
  text-align: center;
}
h3#header {
  text-align: center;
}
img#overview {
  max-width: 800px;
  max-height: 600px;
}
img#style-image {
  max-width: 1000px;
  max-height: 600px;
}
.gr-image {
  max-width: 150px;  /* Set a small box for the image */
  max-height: 150px;
}
'''

# Build the Gradio interface: two upload/generate pairs on the left, the
# resulting CSV downloads on the right, and a button to merge both CSVs.
demo = gr.Blocks(css=css)

with demo:
    gr.Markdown('''<h1 id="title">Image Caption 🖼️</h1>''')
    gr.Markdown('''Made by : No. Fa.''')

    # Always-empty placeholders. gr_process takes (zip_files, image_files),
    # so each button supplies its own upload for one argument and one of
    # these empty states for the other.
    zip_files = gr.State([])
    image_files = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            new_zip_files = gr.File(label="Upload Zip Files", type="filepath", file_count="multiple")
            generate_zip_captions_btn = gr.Button("Generate Zip Captions")
            new_image_files = gr.File(label="Upload Images", type="filepath", file_count="multiple")
            generate_image_captions_btn = gr.Button("Generate Image Captions")
        with gr.Column(scale=3):
            output_zip_file = gr.File(label="Download Zip Captions")
            output_image_file = gr.File(label="Download Image Captions")
            combined_file = gr.File(label="Download Combined Captions")
    combine_files_btn = gr.Button("Combine CSV Files")

    # Both handlers pass two inputs, in gr_process's parameter order. The
    # image button reads the image upload component; it was previously wired
    # to the empty state, so uploaded images never reached the handler.
    generate_zip_captions_btn.click(
        fn=gr_process,
        inputs=[new_zip_files, image_files],
        outputs=output_zip_file,
    )
    generate_image_captions_btn.click(
        fn=gr_process,
        inputs=[zip_files, new_image_files],
        outputs=output_image_file,
    )
    combine_files_btn.click(
        fn=combine_csv_files,
        inputs=[output_zip_file, output_image_file],
        outputs=combined_file,
    )

demo.launch()