"""Compare image-captioning models hosted on the Hugging Face Inference API.

An image uploaded through a Gradio UI is sent to every model listed in
``endpoints``; each model's caption (or error) is rendered as Markdown.
"""

# --- Project dependencies ---
import os
import io
import base64
import requests
import json
import gradio as gr
from PIL import Image
from dotenv import load_dotenv, find_dotenv

# --- Load environment variables ---
_ = load_dotenv(find_dotenv())  # read local .env file
# Raises KeyError at import time if the token is missing.
hf_api_key = os.environ["HF_API_KEY"]

# --- Endpoint URLs ---
endpoint_base_url = "https://api-inference.huggingface.co/models/"
endpoints = [
    "Salesforce/blip-image-captioning-large",
    "Salesforce/blip-image-captioning-base",
    "nlpconnect/vit-gpt2-image-captioning",
]

# Seconds to wait for each model endpoint. Without a timeout, requests can
# block indefinitely on an unresponsive server and freeze the UI.
DEFAULT_TIMEOUT = 60


# --- Define helper functions ---

# Image-to-text completion
def get_completion(inputs, parameters=None, timeout=DEFAULT_TIMEOUT):
    """Send ``inputs`` to every model in ``endpoints`` and collect the replies.

    Args:
        inputs: Value for the ``"inputs"`` field of the JSON request body
            (in this app, a base64-encoded PNG string).
        parameters: Optional value sent as the ``"parameters"`` field.
        timeout: Per-request timeout in seconds passed to ``requests.post``.

    Returns:
        Dict mapping each endpoint name to either the decoded JSON response
        or ``{"error": <message>}`` when the request or decoding failed.
    """
    headers = {
        "Authorization": f"Bearer {hf_api_key}",
        "Content-Type": "application/json",
    }
    payload = {"inputs": inputs}
    if parameters is not None:
        payload["parameters"] = parameters
    body = json.dumps(payload)  # identical for every endpoint; serialize once

    results = {}
    for endpoint in endpoints:
        try:
            response = requests.post(
                endpoint_base_url + endpoint,
                headers=headers,
                data=body,
                timeout=timeout,
            )
            response.raise_for_status()
            results[endpoint] = json.loads(response.content.decode("utf-8"))
        # ValueError covers undecodable / non-JSON bodies (json.JSONDecodeError
        # and UnicodeDecodeError both subclass it); previously these escaped
        # and aborted captioning for all models.
        except (requests.exceptions.RequestException, ValueError) as e:
            print(f"Request to {endpoint} failed: {e}")
            results[endpoint] = {"error": str(e)}
    return results


# Format image as base64 string
def image_to_base64_str(pil_image):
    """Return ``pil_image`` encoded as PNG, then base64, as a UTF-8 str."""
    buffer = io.BytesIO()
    pil_image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _format_caption(model_name, result):
    """Render one model's result as a Markdown snippet for the output panel."""
    header = f"**{model_name.upper()}**: \n"
    if isinstance(result, dict) and "error" in result:
        return f"{header} Error - {result['error']} \n\n\n "
    try:
        # Expected success shape: a list whose first item has "generated_text".
        text = result[0]["generated_text"]
    except (IndexError, KeyError, TypeError):
        # Unexpected payload: report it for this model instead of crashing.
        return f"{header} Error - unexpected response: {result!r} \n\n\n "
    return f"{header} {text} \n\n\n "


# Define captioner function
def captioner(image):
    """Caption ``image`` with every configured model and return Markdown.

    Args:
        image: PIL image from the Gradio ``Image`` component, or ``None``
            if the user submitted without uploading one.

    Returns:
        A single Markdown string with one section per model.
    """
    if image is None:
        # Nothing to caption; avoid an AttributeError in image_to_base64_str.
        return "Please upload an image first."
    results = get_completion(image_to_base64_str(image))
    # endpoint.split("/")[-1] turns "owner/model" into "model".
    return "".join(
        _format_caption(endpoint.split("/")[-1], result)
        for endpoint, result in results.items()
    )


# --- Launch the Gradio App ---
demo = gr.Interface(
    fn=captioner,
    inputs=[gr.Image(label="Upload image", type="pil")],
    outputs=gr.Markdown(label="Captions"),
    title="COMPARE DIFFERENT IMAGE CAPTIONING MODELS",
    description="Upload an image and see how different models caption it",
    # NOTE(review): newer Gradio releases rename this to flagging_mode;
    # confirm against the installed version.
    allow_flagging="never",
)

# Launch only when run as a script (or in a notebook, where __name__ is
# "__main__") so that importing this module does not start a server.
if __name__ == "__main__":
    # share=True creates a public link; debug=True blocks the main thread.
    demo.launch(share=True, debug=True)

    # --- Close all connections ---
    gr.close_all()