cd2412 committed
Commit 602cad6 · verified · 1 Parent(s): 2e24156

Create app.py

Files changed (1)
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
# app.py
"""
VQA — Memory + RL Controller (Gradio app)
- Drag-and-drop an image, ask a question, and see the model's answer + chosen strategy.
- Tries to import `answer_with_controller` from controller.py. Falls back to a stub if missing.
- Works on Hugging Face Spaces, Render, Docker, or local run.
"""

import os
import sys
import time
import traceback
import subprocess
from typing import Tuple, Optional

# Ensure gradio is available when running locally; Spaces installs from requirements.txt
try:
    import gradio as gr
except ImportError:  # pragma: no cover
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "gradio"])
    import gradio as gr

from PIL import Image

# -----------------------------
# Attempt to import real handler
# -----------------------------
def _make_fallback():
    def _fallback_answer_with_controller(
        image: Image.Image,
        question: str,
        source: str = "auto",
        distilled_model: str = "auto",
    ) -> Tuple[str, str, int]:
        # Replace with real inference to remove this placeholder.
        return "Placeholder answer (wire your models in controller.py).", "baseline", 0
    return _fallback_answer_with_controller

try:
    # Expect controller.py to define: answer_with_controller(image, question, source, distilled_model)
    from controller import answer_with_controller  # type: ignore
except Exception as e:
    print(f"[WARN] Using fallback controller because import failed: {e}", flush=True)
    answer_with_controller = _make_fallback()

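For reference, a minimal controller.py that satisfies the contract this import expects might look like the sketch below. Only the signature and the (answer, strategy_name, action_id) return shape come from app.py; the "greedy" strategy name, the action id, and the stub answer are illustrative assumptions, not the real controller.

# controller.py: illustrative stub only; strategy names and ids are assumptions
from typing import Tuple

from PIL import Image


def answer_with_controller(
    image: Image.Image,
    question: str,
    source: str = "auto",
    distilled_model: str = "auto",
) -> Tuple[str, str, int]:
    """Return (answer, strategy_name, action_id) for one image/question pair."""
    # A real implementation would featurize the inputs, let the RL/distilled
    # controller pick a decoding strategy, then run the VQA model with it.
    strategy_name, action_id = "greedy", 0  # hypothetical defaults
    answer = f"(stub) question received: {question!r}"
    return answer, strategy_name, action_id
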
# -----------------------------
# UI Constants
# -----------------------------
TITLE = "VQA — Memory + RL Controller"
DESCRIPTION = (
    "Upload an image, enter a question, and the controller will choose the best decoding strategy."
)

CONTROLLER_SOURCES = ["auto", "distilled", "ppo", "baseline"]
DISTILLED_CHOICES = ["auto", "logreg", "mlp32"]

# -----------------------------
# Inference wrapper with guards
# -----------------------------
def vqa_demo_fn(
    image: Optional[Image.Image],
    question: str,
    source: str,
    distilled_model: str,
) -> Tuple[str, str, float]:
    """Safely run inference and return (answer, strategy_label, latency_ms)."""
    # Input validation
    if image is None:
        return "Please upload an image.", "", 0.0
    question = (question or "").strip()
    if not question:
        return "Please enter a question.", "", 0.0

    # Convert & measure latency
    t0 = time.perf_counter()
    try:
        # Convert to RGB to avoid issues with PNG/L-mode images
        image_rgb = image.convert("RGB")

        pred, strategy_name, action_id = answer_with_controller(
            image_rgb,
            question,
            source=source,
            distilled_model=distilled_model,
        )

        latency_ms = (time.perf_counter() - t0) * 1000.0
        # Friendly formatting
        strategy_out = f"{action_id} → {strategy_name}"
        return str(pred), strategy_out, round(latency_ms, 1)

    except Exception as err:
        # Never crash the app — show a concise error to the user and log details to the server
        latency_ms = (time.perf_counter() - t0) * 1000.0
        # traceback.format_exc() already returns a string; no join needed
        print("[ERROR] Inference failed:\n" + traceback.format_exc(), flush=True)
        return f"Error: {err}", "error", round(latency_ms, 1)

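Because the wrapper catches everything and never raises, it can be smoke-tested directly without launching the UI. A quick check along these lines (a sketch: it assumes app.py is importable from the working directory) exercises all three guard paths against the fallback controller:

# smoke check (hypothetical): run from the same directory as app.py
from PIL import Image

from app import vqa_demo_fn

img = Image.new("RGB", (224, 224), "white")  # any image exercises the fallback

print(vqa_demo_fn(None, "anything", "auto", "auto"))       # missing-image guard
print(vqa_demo_fn(img, "   ", "auto", "auto"))             # empty-question guard
print(vqa_demo_fn(img, "What is shown?", "auto", "auto"))  # full inference path
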
# -----------------------------
# Build Gradio Interface
# -----------------------------
with gr.Blocks(title=TITLE, analytics_enabled=False) as demo:
    gr.Markdown(f"### {TITLE}\n{DESCRIPTION}")

    with gr.Row():
        with gr.Column():
            img_in = gr.Image(
                type="pil",
                label="Image",
                height=320,
                # "upload" already covers drag-and-drop; it is not a separate source
                sources=["upload", "clipboard", "webcam"],
                image_mode="RGB",
            )
            q_in = gr.Textbox(
                label="Question",
                placeholder="e.g., What colour is the bus?",
                lines=2,
                max_lines=4,
            )
            source_in = gr.Radio(
                CONTROLLER_SOURCES,
                value="auto",
                label="Controller Source",
            )
            dist_in = gr.Radio(
                DISTILLED_CHOICES,
                value="auto",
                label="Distilled Gate (if used)",
            )
            run_btn = gr.Button("Predict", variant="primary")
        with gr.Column():
            ans_out = gr.Textbox(label="Answer", interactive=False, lines=3, max_lines=6)
            strat_out = gr.Textbox(label="Chosen Strategy", interactive=False)
            lat_out = gr.Number(label="Latency (ms)", precision=1, interactive=False)

    run_btn.click(
        vqa_demo_fn,
        inputs=[img_in, q_in, source_in, dist_in],
        outputs=[ans_out, strat_out, lat_out],
        api_name="predict",
    )

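Because the click handler sets api_name="predict", the app also exposes a programmatic endpoint. With a recent gradio_client (an assumption; the file-handling helper has changed across client versions), a call could look like:

# hypothetical client-side call; requires `pip install gradio_client`
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860")  # or the Space URL
answer, strategy, latency_ms = client.predict(
    handle_file("bus.jpg"),       # image (hypothetical local file)
    "What colour is the bus?",    # question
    "auto",                       # controller source
    "auto",                       # distilled gate
    api_name="/predict",
)
print(answer, strategy, latency_ms)
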
# -----------------------------
# Launch
# -----------------------------
if __name__ == "__main__":
    # Respect $PORT for Spaces/Render/Docker; default to 7860 locally
    port = int(os.getenv("PORT", "7860"))
    # Queueing improves robustness under load; Gradio 4 renamed queue()'s
    # `concurrency_count` to `default_concurrency_limit`
    demo.queue(default_concurrency_limit=2)
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        share=False,  # set True only for quick local sharing
        show_error=True,
    )
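
Usage note: python app.py serves the demo at http://localhost:7860 by default; platforms that inject a port (Spaces, Render, Docker) are handled through the PORT environment variable, e.g. PORT=8080 python app.py.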