Hugging Face Spaces — Space status: Runtime error. Commit: "Update app.py" (Browse files). File changed: app.py.
|
@@ -18,9 +18,6 @@ from transformers import (
|
|
| 18 |
T5Tokenizer,
|
| 19 |
)
|
| 20 |
|
| 21 |
-
# -------------------------------------------------
|
| 22 |
-
# Device & models
|
| 23 |
-
# -------------------------------------------------
|
| 24 |
device = torch.device("cpu")
|
| 25 |
|
| 26 |
IMG_MODEL = "nlpconnect/vit-gpt2-image-captioning"
|
|
@@ -34,13 +31,12 @@ rewriter_tok = T5Tokenizer.from_pretrained(TXT_MODEL)
|
|
| 34 |
rewriter = T5ForConditionalGeneration.from_pretrained(TXT_MODEL).to(device).eval()
|
| 35 |
|
| 36 |
|
| 37 |
-
# -------------------------------------------------
|
| 38 |
-
# Helpers
|
| 39 |
-
# -------------------------------------------------
|
| 40 |
def load_image(url: str):
|
| 41 |
"""Return (PIL.Image, None) or (None, error). Handles http/https and data‑URL."""
|
| 42 |
try:
|
| 43 |
-
url = url.strip()
|
|
|
|
|
|
|
| 44 |
if url.startswith("data:"):
|
| 45 |
_, data = url.split(",", 1)
|
| 46 |
img = Image.open(BytesIO(base64.b64decode(data))).convert("RGB")
|
|
@@ -55,10 +51,8 @@ def load_image(url: str):
|
|
| 55 |
|
| 56 |
|
| 57 |
def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
|
| 58 |
-
"""Return the longest caption (most detailed) from the vision model."""
|
| 59 |
inputs = processor(images=img, return_tensors="pt")
|
| 60 |
pix = inputs.pixel_values.to(device)
|
| 61 |
-
|
| 62 |
if sample:
|
| 63 |
out = vision.generate(
|
| 64 |
pix,
|
|
@@ -83,12 +77,10 @@ def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
|
|
| 83 |
|
| 84 |
|
| 85 |
def expand_caption(base: str, prompt: str = None, max_len=160):
|
| 86 |
-
"""Use T5 to expand the base caption."""
|
| 87 |
if prompt and prompt.strip():
|
| 88 |
instr = f"Expand using: '{prompt}'. Caption: \"{base}\""
|
| 89 |
else:
|
| 90 |
instr = f"Expand with rich visual detail. Caption: \"{base}\""
|
| 91 |
-
|
| 92 |
toks = rewriter_tok(
|
| 93 |
instr,
|
| 94 |
return_tensors="pt",
|
|
@@ -96,7 +88,6 @@ def expand_caption(base: str, prompt: str = None, max_len=160):
|
|
| 96 |
padding="max_length",
|
| 97 |
max_length=256,
|
| 98 |
).to(device)
|
| 99 |
-
|
| 100 |
out = rewriter.generate(
|
| 101 |
**toks,
|
| 102 |
max_length=max_len,
|
|
@@ -108,36 +99,26 @@ def expand_caption(base: str, prompt: str = None, max_len=160):
|
|
| 108 |
|
| 109 |
|
| 110 |
def async_expand(base, prompt, max_len, status):
|
| 111 |
-
"""Background expansion; updates status dict."""
|
| 112 |
try:
|
| 113 |
status["text"] = "Expanding…"
|
| 114 |
-
time.sleep(0.1)
|
| 115 |
result = expand_caption(base, prompt, max_len)
|
| 116 |
status["text"] = "Done"
|
| 117 |
-
|
| 118 |
except Exception as e:
|
| 119 |
status["text"] = f"Error: {e}"
|
| 120 |
-
|
| 121 |
|
| 122 |
|
| 123 |
-
# -------------------------------------------------
|
| 124 |
-
# Gradio callbacks
|
| 125 |
-
# -------------------------------------------------
|
| 126 |
def fast_describe(url, prompt, detail, beams, sample):
|
| 127 |
img, err = load_image(url)
|
| 128 |
if err:
|
| 129 |
return None, "", err
|
| 130 |
-
|
| 131 |
detail_map = {"Low": 80, "Medium": 140, "High": 220}
|
| 132 |
max_expand = detail_map.get(detail, 140)
|
| 133 |
-
|
| 134 |
base = generate_base(img, beams=beams, sample=sample)
|
| 135 |
-
status = {"text": "Queued…"}
|
| 136 |
-
|
| 137 |
-
def worker():
|
| 138 |
-
status["final"] = async_expand(base, prompt, max_expand, status)
|
| 139 |
-
|
| 140 |
-
threading.Thread(target=worker, daemon=True).start()
|
| 141 |
return img, base, status["text"]
|
| 142 |
|
| 143 |
|
|
@@ -145,10 +126,8 @@ def final_caption(url, prompt, detail, beams, sample):
|
|
| 145 |
img, err = load_image(url)
|
| 146 |
if err:
|
| 147 |
return "", err
|
| 148 |
-
|
| 149 |
detail_map = {"Low": 80, "Medium": 140, "High": 220}
|
| 150 |
max_expand = detail_map.get(detail, 140)
|
| 151 |
-
|
| 152 |
base = generate_base(img, beams=beams, sample=sample)
|
| 153 |
try:
|
| 154 |
final = expand_caption(base, prompt, max_expand)
|
|
@@ -157,9 +136,6 @@ def final_caption(url, prompt, detail, beams, sample):
|
|
| 157 |
return base, f"Expand error: {e}"
|
| 158 |
|
| 159 |
|
| 160 |
-
# -------------------------------------------------
|
| 161 |
-
# UI
|
| 162 |
-
# -------------------------------------------------
|
| 163 |
css = "footer {display:none !important;}"
|
| 164 |
with gr.Blocks() as demo:
|
| 165 |
gr.Markdown("## Image Describer")
|
|
@@ -189,15 +165,6 @@ with gr.Blocks() as demo:
|
|
| 189 |
outputs=[caption_out, status_out],
|
| 190 |
)
|
| 191 |
|
| 192 |
-
# -------------------------------------------------
|
| 193 |
-
# Launch
|
| 194 |
-
# -------------------------------------------------
|
| 195 |
if __name__ == "__main__":
|
| 196 |
demo.queue()
|
| 197 |
-
demo.launch(
|
| 198 |
-
server_name="0.0.0.0",
|
| 199 |
-
server_port=7860,
|
| 200 |
-
css=css,
|
| 201 |
-
title="Image Describer (CPU)",
|
| 202 |
-
prevent_thread_lock=True,
|
| 203 |
-
)
|
|
|
|
| 18 |
T5Tokenizer,
|
| 19 |
)
|
| 20 |
|
|
|
|
|
|
|
|
|
|
| 21 |
device = torch.device("cpu")
|
| 22 |
|
| 23 |
IMG_MODEL = "nlpconnect/vit-gpt2-image-captioning"
|
|
|
|
| 31 |
rewriter = T5ForConditionalGeneration.from_pretrained(TXT_MODEL).to(device).eval()
|
| 32 |
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
def load_image(url: str):
|
| 35 |
"""Return (PIL.Image, None) or (None, error). Handles http/https and data‑URL."""
|
| 36 |
try:
|
| 37 |
+
url = (url or "").strip()
|
| 38 |
+
if not url:
|
| 39 |
+
return None, "No URL provided."
|
| 40 |
if url.startswith("data:"):
|
| 41 |
_, data = url.split(",", 1)
|
| 42 |
img = Image.open(BytesIO(base64.b64decode(data))).convert("RGB")
|
|
|
|
| 51 |
|
| 52 |
|
| 53 |
def generate_base(img: Image.Image, max_len=40, beams=2, sample=False):
|
|
|
|
| 54 |
inputs = processor(images=img, return_tensors="pt")
|
| 55 |
pix = inputs.pixel_values.to(device)
|
|
|
|
| 56 |
if sample:
|
| 57 |
out = vision.generate(
|
| 58 |
pix,
|
|
|
|
| 77 |
|
| 78 |
|
| 79 |
def expand_caption(base: str, prompt: str = None, max_len=160):
|
|
|
|
| 80 |
if prompt and prompt.strip():
|
| 81 |
instr = f"Expand using: '{prompt}'. Caption: \"{base}\""
|
| 82 |
else:
|
| 83 |
instr = f"Expand with rich visual detail. Caption: \"{base}\""
|
|
|
|
| 84 |
toks = rewriter_tok(
|
| 85 |
instr,
|
| 86 |
return_tensors="pt",
|
|
|
|
| 88 |
padding="max_length",
|
| 89 |
max_length=256,
|
| 90 |
).to(device)
|
|
|
|
| 91 |
out = rewriter.generate(
|
| 92 |
**toks,
|
| 93 |
max_length=max_len,
|
|
|
|
| 99 |
|
| 100 |
|
| 101 |
def async_expand(base, prompt, max_len, status):
    """Expand *base* in the background, reporting progress through *status*.

    The *status* dict is mutated in place: "text" tracks progress
    ("Expanding…" → "Done" or "Error: …") and "final" receives the
    expanded caption — or the unmodified base caption if expansion fails.
    """
    try:
        status["text"] = "Expanding…"
        # Brief pause so the "Expanding…" state is observable in the UI.
        time.sleep(0.1)
        expanded = expand_caption(base, prompt, max_len)
        status["text"] = "Done"
        status["final"] = expanded
    except Exception as exc:  # boundary: surface the error via status, don't crash the thread
        status["text"] = f"Error: {exc}"
        status["final"] = base
|
| 111 |
|
| 112 |
|
|
|
|
|
|
|
|
|
|
| 113 |
def fast_describe(url, prompt, detail, beams, sample):
    """Fast path: return the image and base caption right away, then kick off
    caption expansion in a daemon thread (progress lands in *status*).

    Returns (image, base_caption, status_text); on load failure returns
    (None, "", error_message).
    """
    img, err = load_image(url)
    if err:
        return None, "", err

    # Translate the requested detail level into a max expansion length.
    length_by_detail = {"Low": 80, "Medium": 140, "High": 220}
    max_expand = length_by_detail.get(detail, 140)

    base = generate_base(img, beams=beams, sample=sample)

    status = {"text": "Queued…", "final": ""}
    worker = threading.Thread(
        target=async_expand,
        args=(base, prompt, max_expand, status),
        daemon=True,
    )
    worker.start()
    return img, base, status["text"]
|
| 123 |
|
| 124 |
|
|
|
|
| 126 |
img, err = load_image(url)
|
| 127 |
if err:
|
| 128 |
return "", err
|
|
|
|
| 129 |
detail_map = {"Low": 80, "Medium": 140, "High": 220}
|
| 130 |
max_expand = detail_map.get(detail, 140)
|
|
|
|
| 131 |
base = generate_base(img, beams=beams, sample=sample)
|
| 132 |
try:
|
| 133 |
final = expand_caption(base, prompt, max_expand)
|
|
|
|
| 136 |
return base, f"Expand error: {e}"
|
| 137 |
|
| 138 |
|
|
|
|
|
|
|
|
|
|
| 139 |
css = "footer {display:none !important;}"
|
| 140 |
with gr.Blocks() as demo:
|
| 141 |
gr.Markdown("## Image Describer")
|
|
|
|
| 165 |
outputs=[caption_out, status_out],
|
| 166 |
)
|
| 167 |
|
|
|
|
|
|
|
|
|
|
| 168 |
if __name__ == "__main__":
    demo.queue()
    # NOTE(review): Blocks.launch() accepts no `css` kwarg — passing css=css
    # raises TypeError at startup (the likely cause of the Space's "Runtime
    # error"). CSS must be supplied as gr.Blocks(css=...) where the UI is
    # built. `prevent_thread_lock=True` is also dropped: it makes launch()
    # return immediately, so the script exits and the server is torn down.
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (required on Spaces)
        server_port=7860,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|