import gradio as gr
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModelForCausalLM

# Run everything on CPU in float32 (half precision is slow or unsupported on CPU)
DEVICE = torch.device("cpu")
torch.set_default_dtype(torch.float32)

# Load CLIP model and processor (ViT-B/32 image features are 512-dimensional)
try:
    clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        torch_dtype=torch.float32
    ).to(DEVICE)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
except Exception as e:
    raise Exception(f"Error loading CLIP model or processor: {str(e)}")

# Load language model and tokenizer
def load_model():
    try:
        #model_name = "distilgpt2"
        model_name="microsoft/phi-3-mini-4k-instruct"
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            trust_remote_code=True
        ).to(DEVICE)
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = model.config.eos_token_id

        model.eval()
        return model, tokenizer

    except Exception as e:
        raise Exception(f"Error loading language model: {str(e)}")

# Caption generation logic
def generate_caption(image, model, tokenizer):
    try:
        # Gradio passes a PIL image when type="pil"; reject anything else and
        # normalize the mode to RGB (handles RGBA, grayscale, palette inputs)
        if not isinstance(image, Image.Image):
            return "Error: Input must be a valid image."
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Encode the image with CLIP; get_image_features returns a [1, 512] tensor
        image_inputs = clip_processor(images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            image_embedding = clip_model.get_image_features(**image_inputs).to(torch.float32)

        prompt = "[IMG] Caption this image:"
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs["input_ids"].to(DEVICE)
        attention_mask = inputs["attention_mask"].to(DEVICE)

        # Project the 512-dim CLIP image embedding into the language model's
        # hidden size. Note: this Linear layer is re-created with random weights
        # on every call, so the visual prefix carries no trained alignment.
        projection = torch.nn.Linear(512, model.config.hidden_size).to(DEVICE)
        with torch.no_grad():
            image_embedding_projected = projection(image_embedding)
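        # A sketch, assuming fine-tuned projection weights had been saved to a
        # hypothetical file such as "epoch_0_projection.pt", of how the layer
        # could be created and loaded once at startup instead of per call:
        #
        #     projection = torch.nn.Linear(512, model.config.hidden_size).to(DEVICE)
        #     projection.load_state_dict(torch.load("epoch_0_projection.pt", map_location=DEVICE))
        #     projection.eval()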

        # Prepend the projected image embedding as a single "visual token" ahead
        # of the prompt embeddings, and extend the attention mask to match
        text_embedding = model.get_input_embeddings()(input_ids)
        fused_embedding = torch.cat([image_embedding_projected.unsqueeze(1), text_embedding], dim=1)
        attention_mask = torch.cat([
            torch.ones(input_ids.size(0), 1, device=DEVICE),
            attention_mask
        ], dim=1)
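        # Shapes at this point: fused_embedding is [1, 1 + prompt_len, hidden_size],
        # attention_mask is [1, 1 + prompt_len].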

        with torch.no_grad():
            generated_ids = model.generate(
                inputs_embeds=fused_embedding,
                attention_mask=attention_mask,
                max_new_tokens=50,
                min_length=10,
                num_beams=3,
                repetition_penalty=1.2,
                do_sample=False
            )
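        # When only inputs_embeds is supplied, generate() returns just the newly
        # generated token ids (in recent transformers versions), so the decoded
        # text is the caption without the prompt.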
        caption = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return caption.strip()

    except Exception as e:
        return f"Error generating caption: {str(e)}"

# Load model/tokenizer
model, tokenizer = load_model()

# Wrapper for Gradio function call
def gradio_caption(image):
    if image is None:
        return "Please upload an image."
    return generate_caption(image, model, tokenizer)
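
# Quick check of the pipeline without the UI (assumes a hypothetical local file
# "example.jpg" next to this script):
#
#     print(gradio_caption(Image.open("example.jpg")))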

# Reusable UI component blocks
def create_image_input():
    return gr.Image(
        type="pil",
        label="Upload an Image",
        sources=["upload"]
    )

def create_caption_output():
    return gr.Textbox(
        label="Generated Caption",
        lines=2,
        placeholder="Caption will appear here..."
    )

# Build UI
interface = gr.Interface(
    fn=gradio_caption,
    inputs=create_image_input(),
    outputs=create_caption_output(),
    title="Image Captioning with Fine-Tuned MultiModalModel (Epoch 0)",
    description=(
        "Upload an image to generate a caption using a fine-tuned multimodal model based on Phi-3 and CLIP. "
        "The weights from Epoch_0 are used here, but the model may not generate accurate captions due to limited training."
    )
)
if __name__ == "__main__":
    interface.launch()
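
# Note: launch() also accepts standard Gradio options such as share=True (public
# link) or server_name="0.0.0.0" (listen on all interfaces) if needed.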