joaomorossini commited on
Commit
9daa820
1 Parent(s): 4e405b8

Initial commit to HF spaces

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py CHANGED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --- Project dependencies ---
2
+ import os
3
+ import io
4
+ import base64
5
+ import requests
6
+ import json
7
+ import gradio as gr
8
+ from PIL import Image
9
+ from dotenv import load_dotenv, find_dotenv
10
+
11
+
12
+ # --- Load environment variables ---
13
+ _ = load_dotenv(find_dotenv()) # read local .env file
14
+ hf_api_key = os.environ["HF_API_KEY"]
15
+
16
+
17
+ # --- Endpoint URLs ---
18
+ endpoint_base_url = "https://api-inference.huggingface.co/models/"
19
+
20
+ endpoints = [
21
+ "Salesforce/blip-image-captioning-large",
22
+ "Salesforce/blip-image-captioning-base",
23
+ "nlpconnect/vit-gpt2-image-captioning",
24
+ ]
25
+
26
+
27
+ # --- Define helper functions ---
28
+
29
+
30
+ # Image-to-text completion
31
+ def get_completion(inputs, parameters=None):
32
+ headers = {
33
+ "Authorization": f"Bearer {hf_api_key}",
34
+ "Content-Type": "application/json",
35
+ }
36
+ data = {"inputs": inputs}
37
+ if parameters is not None:
38
+ data.update({"parameters": parameters})
39
+
40
+ results = {}
41
+ for endpoint in endpoints:
42
+ try:
43
+ response = requests.post(
44
+ endpoint_base_url + endpoint,
45
+ headers=headers,
46
+ data=json.dumps(data),
47
+ )
48
+ response.raise_for_status()
49
+ results[endpoint] = json.loads(response.content.decode("utf-8"))
50
+ except requests.exceptions.RequestException as e:
51
+ print(f"Request to {endpoint} failed: {e}")
52
+ results[endpoint] = {"error": str(e)}
53
+
54
+ return results
55
+
56
+
57
+ # Format image as base64 string
58
+ def image_to_base64_str(pil_image):
59
+ byte_arr = io.BytesIO()
60
+ pil_image.save(byte_arr, format="PNG")
61
+ byte_arr = byte_arr.getvalue()
62
+ return str(base64.b64encode(byte_arr).decode("utf-8"))
63
+
64
+
65
+ # Define captioner function
66
+ def captioner(image):
67
+ base64_image = image_to_base64_str(image)
68
+ results = get_completion(base64_image)
69
+ captions = []
70
+ for endpoint, result in results.items():
71
+ model_name = endpoint.split("/")[-1] # Extract the model name from the endpoint
72
+ if "error" not in result:
73
+ caption = (
74
+ f"**{model_name.upper()}**: \n {result[0]['generated_text']} \n\n\n "
75
+ )
76
+ else:
77
+ caption = f"**{model_name.upper()}**: \n Error - {result['error']} \n\n\n "
78
+ captions.append(caption)
79
+ return "".join(captions) # Join all captions into a single string
80
+
81
+
82
+ # --- Launch the Gradio App ---
83
+ demo = gr.Interface(
84
+ fn=captioner,
85
+ inputs=[gr.Image(label="Upload image", type="pil")],
86
+ outputs=gr.Markdown(label="Captions"),
87
+ title="COMPARE DIFFERENT IMAGE CAPTIONING MODELS",
88
+ description="Upload an image and see how different models caption it",
89
+ allow_flagging="never",
90
+ )
91
+
92
+ demo.launch(share=True, debug=True)
93
+
94
+
95
+ # --- Close all connections ---
96
+ gr.close_all()