|
import base64
import json
import logging
import os
import uuid
from pathlib import Path

import gradio as gr
import torch

from src.llm.chat import FunctionCallingChat
|
|
|
# Single module-level chat instance shared by every request this app handles.
chatbot = FunctionCallingChat()
# Initial sampling temperature; `inference` overwrites it on each request
# with the value coming from the temperature slider.
chatbot.temperature = 0.7
|
|
|
def image_to_base64(image_path: str):
    """Read the file at *image_path* and return its bytes as a base64 text string."""
    with open(image_path, mode="rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    # b64encode yields bytes; callers embed the result in HTML, so hand back str.
    return encoded.decode("utf-8")
|
|
|
def save_uploaded_image(pil_img) -> Path:
    """Write *pil_img* into ``./static`` under a random name and return the file path.

    The directory is created on demand. The file name has the form
    ``upload_<8 hex chars>.png``, the hex part being taken from a fresh UUID4.
    """
    upload_dir = Path("static")
    upload_dir.mkdir(exist_ok=True)

    token = uuid.uuid4().hex[:8]
    target = upload_dir / f"upload_{token}.png"

    pil_img.save(target)
    return target
|
|
|
def _build_user_message(prompt, task, img_path):
    """Compose the instruction sent to the chat model for the selected task."""
    if task == "Detection":
        return f"Please detect objects in the image '{img_path}'."
    if task == "Segmentation":
        return f"Please segment objects in the image '{img_path}'."
    # Any other task value (the UI offers "Auto") forwards the user's own prompt.
    # The prompt may arrive as None, so normalise it before stripping.
    prompt = (prompt or "").strip() or "Analyse this image."
    return f"{prompt} (image: '{img_path}')"


def inference(pil_img, prompt, task, temperature):
    """Run the tool-calling chat on an uploaded image and format the outcome as Markdown.

    Args:
        pil_img: Uploaded image object exposing ``save(path)``; ``None`` when
            nothing was uploaded.
        prompt: Free-text instruction, used only when *task* is neither
            "Detection" nor "Segmentation". Empty or ``None`` falls back to
            "Analyse this image.".
        task: Task selector; "Detection" and "Segmentation" use fixed
            instructions, anything else uses *prompt*.
        temperature: Sampling temperature assigned to the shared ``chatbot``.

    Returns:
        A Markdown string with the raw tool call and the JSON-formatted tool
        results, or a short notice when no image was supplied.

    Exceptions raised by the chat call or by JSON serialisation propagate to
    the caller; the temporary image file is removed either way.
    """
    if pil_img is None:
        return "β Please upload an image first."

    img_path = save_uploaded_image(pil_img)
    try:
        # NOTE: this mutates the module-level chatbot, so the value set here
        # is whatever a later call on the same instance will see until it is
        # overwritten again.
        chatbot.temperature = temperature

        # The chat result is indexed by 'raw_tool_call' and 'results' below.
        out = chatbot(_build_user_message(prompt, task, img_path))
        return (
            "### π§ Raw tool-call\n"
            f"{out['raw_tool_call']}\n\n"
            "### π¦ Tool results\n"
            f"{json.dumps(out['results'], indent=2)}"
        )
    finally:
        # Best-effort cleanup of the temporary upload: a failure here must not
        # mask the result (or the exception) of the chat call, but it should
        # leave a trace instead of disappearing silently.
        try:
            img_path.unlink(missing_ok=True)
        except OSError as exc:
            logging.getLogger(__name__).warning(
                "Could not remove temporary upload %s: %s", img_path, exc
            )
|
|
|
def create_header():
    """Lay out the page header: logo column on the left, title banner on the right.

    The logo is read from ``static/aivn_logo.png`` and embedded as a base64
    data URI. Intended to be called inside the ``gr.Blocks`` context, since it
    creates layout components.
    """
    with gr.Row():
        # Narrow column holding the inline-encoded logo.
        with gr.Column(scale=1):
            encoded_logo = image_to_base64("static/aivn_logo.png")
            logo_tag = f"""<img src="data:image/png;base64,{encoded_logo}"
                     alt="Logo"
                     style="height:120px;width:auto;margin-right:20px;margin-bottom:20px;">"""
            gr.HTML(logo_tag)

        # Wide column holding the title and the project banner.
        with gr.Column(scale=4):
            banner = """
            <div style="display:flex;justify-content:space-between;align-items:center;padding:0 15px;">
                <div>
                    <h1 style="margin-bottom:0;">πΌοΈ Vision Tool-Calling Demo</h1>
                    <p style="margin-top:0.5em;color:#666;">LLM-driven Detection & Segmentation</p>
                </div>
                <div style="text-align:right;border-left:2px solid #ddd;padding-left:20px;">
                    <h3 style="margin:0;color:#2c3e50;">π AIO2024 Module 10 Project π€</h3>
                    <p style="margin:0;color:#7f8c8d;">π Using Llama 3.2-1B + YOLO + SAM</p>
                </div>
            </div>
            """
            gr.Markdown(banner)
|
|
|
def create_footer():
    """Create the fixed-position footer with the credit link and return the HTML component.

    The markup ships its own ``<style>`` rules for the ``sticky-footer`` and
    ``content-wrap`` classes.
    """
    markup = """
    <style>
      .sticky-footer{position:fixed;bottom:0;left:0;width:100%;background:white;
                     padding:10px;box-shadow:0 -2px 10px rgba(0,0,0,0.1);z-index:1000;}
      .content-wrap{padding-bottom:60px;}
    </style>
    <div class="sticky-footer">
      <div style="text-align:center;font-size:14px;">
        Created by <a href="https://vlai.work" target="_blank"
        style="color:#007BFF;text-decoration:none;">VLAI</a> β’ AI VIETNAM
      </div>
    </div>
    """
    footer = gr.HTML(markup)
    return footer
|
|
|
# Global stylesheet handed to ``gr.Blocks(css=...)``: minimum page height,
# bottom padding for ``.content-wrap``, and the gradient look (plus hover
# state) of elements carrying the ``full-width-btn`` class.
custom_css = """
.gradio-container {min-height:100vh;}
.content-wrap {padding-bottom:60px;}
.full-width-btn {width:100%!important;height:50px!important;font-size:18px!important;
margin-top:20px!important;background:linear-gradient(45deg,#FF6B6B,#4ECDC4)!important;
color:white!important;border:none!important;}
.full-width-btn:hover {background:linear-gradient(45deg,#FF5252,#3CB4AC)!important;}
"""
|
|
|
|
|
# Page assembly: header, input controls on the left, Markdown result on the
# right, then the footer. Component creation order defines the layout.
with gr.Blocks(css=custom_css) as demo:
    create_header()

    with gr.Row(equal_height=True, variant="panel"):
        # Left column: everything the user provides.
        with gr.Column(scale=3):
            upload_image = gr.Image(type="pil", label="Upload image")
            prompt_input = gr.Textbox(
                label="Optional prompt",
                placeholder="e.g. Detect cats only",
            )
            task_choice = gr.Radio(
                choices=["Auto", "Detection", "Segmentation"],
                value="Auto",
                label="Task",
            )
            temp_slider = gr.Slider(
                label="Temperature (sampling)",
                value=0.7,
                minimum=0.1,
                maximum=1.5,
                step=0.1,
            )
            submit_btn = gr.Button(value="Run π§", elem_classes="full-width-btn")

        # Right column: the Markdown text returned by `inference`.
        with gr.Column(scale=4):
            output_text = gr.Markdown(label="Result")

    # Input order must match the parameter order of `inference`.
    submit_btn.click(
        fn=inference,
        inputs=[upload_image, prompt_input, task_choice, temp_slider],
        outputs=output_text,
    )

    create_footer()
|
|
|
if __name__ == "__main__":
    # Paths handed to launch() as allowed_paths: the logo file and the
    # ``static`` directory.
    served_paths = ["static/aivn_logo.png", "static"]
    demo.launch(allowed_paths=served_paths)
|
|