bhsinghgrid committed on
Commit
2f09c83
·
verified ·
1 Parent(s): 27f26fd

Fix UI load flow and align generation logic with inference.py

Browse files
Files changed (1) hide show
  1. app.py +502 -174
app.py CHANGED
@@ -1,235 +1,563 @@
1
- """
2
- Hugging Face Space app for Sanskrit D3PM project.
3
-
4
- Deploy on Spaces with:
5
- app_file = app_hf_space.py
6
-
7
- Optional environment variables:
8
- HF_CHECKPOINT_REPO : model repo id (e.g. "username/sanskrit-d3pm")
9
- HF_CHECKPOINT_FILE : checkpoint path in repo (default: "best_model.pt")
10
- HF_CHECKPOINT_LABEL : UI label for remote checkpoint
11
- """
12
-
13
- from __future__ import annotations
14
-
15
  import copy
 
16
  import os
17
- from typing import Dict, Tuple
 
 
18
 
19
  import gradio as gr
20
  import torch
21
 
22
  from config import CONFIG
23
- from inference import _build_tokenizers, _resolve_device, load_model, run_inference
 
24
 
25
 
26
- def _clean_output(text: str, max_repeat: int = 2) -> str:
27
- text = " ".join(text.split())
28
- if not text:
29
- return text
30
- toks = text.split()
31
- out = []
32
- prev = None
33
- run = 0
34
- for t in toks:
35
- if t == prev:
36
- run += 1
37
- else:
38
- prev = t
39
- run = 1
40
- if run <= max_repeat:
41
- out.append(t)
42
- s = " ".join(out)
43
- s = s.replace(" ।", "।").replace(" ॥", "॥")
44
- return " ".join(s.split())
45
 
46
 
47
- def _discover_local_checkpoints() -> Dict[str, str]:
48
- found = {}
49
  for root in ("ablation_results", "results7", "results"):
50
  if not os.path.isdir(root):
51
  continue
52
- for exp in sorted(os.listdir(root)):
53
- ckpt = os.path.join(root, exp, "best_model.pt")
54
- if os.path.exists(ckpt):
55
- found[f"{exp} [{root}]"] = ckpt
 
 
 
 
 
 
 
 
56
  return found
57
 
58
 
59
- def _discover_remote_checkpoint() -> Dict[str, str]:
60
- repo = os.getenv("HF_CHECKPOINT_REPO", "").strip()
61
- if not repo:
62
- return {}
63
-
64
- filename = os.getenv("HF_CHECKPOINT_FILE", "best_model.pt").strip()
65
- label = os.getenv("HF_CHECKPOINT_LABEL", f"remote:{repo}")
66
 
67
- try:
68
- from huggingface_hub import hf_hub_download
69
 
70
- ckpt_path = hf_hub_download(repo_id=repo, filename=filename)
71
- return {label: ckpt_path}
72
- except Exception as e:
73
- print(f"[WARN] remote checkpoint download failed: {e}")
74
- return {}
 
 
 
75
 
76
 
77
- def _infer_model_type(path: str) -> str:
78
- p = path.lower()
79
- if "d3pm_encoder_decoder" in p:
 
 
 
80
  return "d3pm_encoder_decoder"
81
- if "baseline_cross_attention" in p:
82
  return "baseline_cross_attention"
83
- if "baseline_encoder_decoder" in p:
84
  return "baseline_encoder_decoder"
85
- return "d3pm_cross_attention"
86
 
87
 
88
- def _infer_neg(path: str) -> bool:
89
- p = path.lower()
90
- if "_neg_true" in p:
 
91
  return True
92
- if "_neg_false" in p:
93
  return False
94
  return CONFIG["data"]["include_negative_examples"]
95
 
96
 
97
- class RuntimeStore:
98
- def __init__(self):
99
- self.loaded: Dict[str, Dict] = {}
100
-
101
- def get(self, ckpt_label: str, ckpt_path: str) -> Dict:
102
- if ckpt_label in self.loaded:
103
- return self.loaded[ckpt_label]
104
-
105
- cfg = copy.deepcopy(CONFIG)
106
- cfg["model_type"] = _infer_model_type(ckpt_path)
107
- cfg["data"]["include_negative_examples"] = _infer_neg(ckpt_path)
108
- device = _resolve_device(cfg)
109
-
110
- model, cfg = load_model(ckpt_path, cfg, device)
111
- src_tok, tgt_tok = _build_tokenizers(cfg)
112
-
113
- bundle = {
114
- "label": ckpt_label,
115
- "path": ckpt_path,
116
- "cfg": cfg,
117
- "device": str(device),
118
- "model": model,
119
- "src_tok": src_tok,
120
- "tgt_tok": tgt_tok,
121
- }
122
- self.loaded[ckpt_label] = bundle
123
- return bundle
124
-
125
-
126
- RUNTIME = RuntimeStore()
127
- CHECKPOINTS = {}
128
- CHECKPOINTS.update(_discover_local_checkpoints())
129
- CHECKPOINTS.update(_discover_remote_checkpoint())
130
-
131
- if not CHECKPOINTS:
132
- CHECKPOINTS = {"No checkpoint found": ""}
133
-
134
-
135
- def load_checkpoint_ui(label: str) -> Tuple[Dict, str]:
136
- if label not in CHECKPOINTS or not CHECKPOINTS[label]:
137
- raise gr.Error("No valid checkpoint found. Upload/provide best_model.pt first.")
138
- bundle = RUNTIME.get(label, CHECKPOINTS[label])
139
- info = (
140
- f"Loaded `{label}`\n"
141
- f"- path: `{bundle['path']}`\n"
142
- f"- model_type: `{bundle['cfg']['model_type']}`\n"
143
- f"- device: `{bundle['device']}`\n"
144
- f"- max_seq_len: `{bundle['cfg']['model']['max_seq_len']}`"
145
  )
146
- return bundle, info
147
-
148
-
149
- def generate_ui(
150
- bundle: Dict,
151
- text: str,
152
- temperature: float,
153
- top_k: int,
154
- repetition_penalty: float,
155
- diversity_penalty: float,
156
- num_steps: int,
157
- clean_output: bool,
158
- ) -> str:
159
- if not bundle:
160
- raise gr.Error("Load a checkpoint first.")
161
- if not text.strip():
162
- raise gr.Error("Enter input text.")
163
-
164
- cfg = copy.deepcopy(bundle["cfg"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  cfg["inference"]["temperature"] = float(temperature)
166
  cfg["inference"]["top_k"] = int(top_k)
167
  cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
168
  cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
169
  cfg["inference"]["num_steps"] = int(num_steps)
170
 
171
- src_tok = bundle["src_tok"]
172
- tgt_tok = bundle["tgt_tok"]
173
- device = torch.device(bundle["device"])
174
- ids = torch.tensor([src_tok.encode(text.strip())], dtype=torch.long, device=device)
175
 
176
- out = run_inference(bundle["model"], ids, cfg)
177
- token_ids = [x for x in out[0].tolist() if x > 4]
178
- pred = tgt_tok.decode(token_ids).strip()
 
 
 
 
 
 
179
  if clean_output:
180
- pred = _clean_output(pred)
181
- return pred if pred else "(empty output)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
 
184
- with gr.Blocks(title="Sanskrit D3PM Space") as demo:
185
  model_state = gr.State(None)
 
186
  gr.Markdown(
187
  """
188
- ## Sanskrit D3PM Paraphrase (IAST → Devanagari)
189
- Load a trained checkpoint and generate output from Roman/IAST Sanskrit input.
 
 
 
 
190
  """
191
  )
192
 
193
- checkpoint = gr.Dropdown(
194
- choices=list(CHECKPOINTS.keys()),
195
- value=list(CHECKPOINTS.keys())[0],
196
- label="Checkpoint",
197
- )
198
- load_btn = gr.Button("Load Model", variant="primary")
199
- load_info = gr.Markdown("Select a checkpoint and click **Load Model**.")
200
-
201
- text_in = gr.Textbox(label="Input (Roman / IAST)", lines=3, value="dharmo rakṣati rakṣitaḥ")
202
- text_out = gr.Textbox(label="Output (Devanagari)", lines=6)
203
-
204
  with gr.Row():
205
- temperature = gr.Slider(0.4, 1.2, value=0.70, step=0.05, label="Temperature")
206
- top_k = gr.Slider(5, 100, value=40, step=1, label="Top-K")
207
- repetition_penalty = gr.Slider(1.0, 3.0, value=1.20, step=0.05, label="Repetition Penalty")
208
- diversity_penalty = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Diversity Penalty")
209
- num_steps = gr.Slider(1, 128, value=64, step=1, label="Inference Steps")
210
- clean_output = gr.Checkbox(value=True, label="Clean Output")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
- generate_btn = gr.Button("Generate", variant="primary")
 
 
 
 
213
 
214
- load_btn.click(load_checkpoint_ui, inputs=[checkpoint], outputs=[model_state, load_info])
215
  generate_btn.click(
216
- generate_ui,
217
  inputs=[
218
- model_state, text_in, temperature, top_k, repetition_penalty,
219
- diversity_penalty, num_steps, clean_output
 
 
 
 
 
 
220
  ],
221
- outputs=[text_out],
222
  )
223
- text_in.submit(
224
- generate_ui,
225
  inputs=[
226
- model_state, text_in, temperature, top_k, repetition_penalty,
227
- diversity_penalty, num_steps, clean_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  ],
229
- outputs=[text_out],
230
  )
231
 
232
 
233
  if __name__ == "__main__":
234
- port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
235
- demo.launch(server_name="0.0.0.0", server_port=port, share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import copy
2
+ import json
3
  import os
4
+ import subprocess
5
+ import sys
6
+ from datetime import datetime
7
 
8
  import gradio as gr
9
  import torch
10
 
11
  from config import CONFIG
12
+ from inference import _resolve_device, load_model, run_inference, _decode_clean, _decode_with_cleanup
13
+ from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
14
 
15
 
16
# Where per-experiment UI generation logs are appended (JSON, one file per day).
RESULTS_DIR = "generated_results"
# Default output directory offered to the analysis task runner.
DEFAULT_ANALYSIS_OUT = "analysis/outputs"
# Created eagerly so save_generation() can write without checking.
os.makedirs(RESULTS_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
def discover_checkpoints():
    """Scan the known result directories for trained checkpoints.

    Looks for ``<root>/<experiment>/best_model.pt`` under ablation_results/,
    results7/ and results/ and returns one descriptor dict per hit with
    keys: label, path, experiment, root.
    """
    roots = ("ablation_results", "results7", "results")
    discovered = []
    for root in roots:
        if not os.path.isdir(root):
            continue
        # Sorted so dropdown order is stable across reloads.
        for experiment in sorted(os.listdir(root)):
            candidate = os.path.join(root, experiment, "best_model.pt")
            if os.path.exists(candidate):
                discovered.append(
                    {
                        "label": f"{experiment} [{root}]",
                        "path": candidate,
                        "experiment": experiment,
                        "root": root,
                    }
                )
    return discovered
39
 
40
 
41
def checkpoint_map():
    """Return a label -> checkpoint-descriptor mapping for the UI dropdown."""
    mapping = {}
    for descriptor in discover_checkpoints():
        mapping[descriptor["label"]] = descriptor
    return mapping
 
 
 
 
 
43
 
 
 
44
 
45
def default_checkpoint_label():
    """Pick the checkpoint label to preselect in the UI.

    Prefers the T4 ablation checkpoint when present, otherwise the first
    discovered checkpoint.

    Returns:
        The preferred label string, or None when no checkpoints exist.
    """
    checkpoints = discover_checkpoints()
    if not checkpoints:
        return None
    # Build the preferred suffix with os.path.join and compare normalized
    # paths: the original hard-coded "ablation_results/T4/best_model.pt"
    # never matched on Windows, where discover_checkpoints() produces
    # backslash-separated paths.
    preferred = os.path.join("ablation_results", "T4", "best_model.pt")
    for item in checkpoints:
        if os.path.normpath(item["path"]).endswith(preferred):
            return item["label"]
    return checkpoints[0]["label"]
53
 
54
 
55
def infer_model_type(experiment_name: str, root: str = "") -> str:
    """Derive the model architecture name from an experiment folder name.

    All ablation_results runs used the cross-attention D3PM; otherwise the
    experiment-name prefix decides, falling back to the configured default.
    """
    if root == "ablation_results":
        return "d3pm_cross_attention"
    known_types = (
        "d3pm_cross_attention",
        "d3pm_encoder_decoder",
        "baseline_cross_attention",
        "baseline_encoder_decoder",
    )
    for model_type in known_types:
        if experiment_name.startswith(model_type):
            return model_type
    # Unknown naming scheme: trust the project-wide default.
    return CONFIG["model_type"]
67
 
68
 
69
def infer_include_negative(experiment_name: str, root: str = "") -> bool:
    """Derive whether the run was trained with negative examples.

    Ablation runs never used negatives; otherwise the experiment-name marker
    decides, falling back to the project-wide config default.
    """
    if root == "ablation_results":
        return False
    for marker, flag in (("_neg_True", True), ("_neg_False", False)):
        if marker in experiment_name:
            return flag
    return CONFIG["data"]["include_negative_examples"]
77
 
78
 
79
def build_runtime_cfg(ckpt_path: str):
    """Derive a runtime config, device and experiment name from a checkpoint path.

    The path layout is <root>/<experiment>/best_model.pt; both parent folder
    names feed the model-type / negatives inference.
    """
    ckpt_dir = os.path.dirname(ckpt_path)
    experiment = os.path.basename(ckpt_dir)
    root = os.path.basename(os.path.dirname(ckpt_dir))

    cfg = copy.deepcopy(CONFIG)
    cfg["model_type"] = infer_model_type(experiment, root=root)
    cfg["data"]["include_negative_examples"] = infer_include_negative(experiment, root=root)

    # Ablation folders named T<k> encode the diffusion-step count directly.
    if root == "ablation_results" and experiment.startswith("T") and experiment[1:].isdigit():
        steps = int(experiment[1:])
        cfg["model"]["diffusion_steps"] = steps
        cfg["inference"]["num_steps"] = steps

    device = _resolve_device(cfg.get("training", {}).get("device", "cpu"))
    return cfg, device, experiment
93
+
94
+
95
def _build_tokenizers(cfg):
    """Instantiate the source/target tokenizers sized from the model config.

    Vocab sizes fall back to 16000 when absent; both share max_seq_len.
    """
    max_len = cfg["model"]["max_seq_len"]
    source = SanskritSourceTokenizer(
        vocab_size=cfg["model"].get("src_vocab_size", 16000),
        max_len=max_len,
    )
    target = SanskritTargetTokenizer(
        vocab_size=cfg["model"].get("tgt_vocab_size", 16000),
        max_len=max_len,
    )
    return source, target
105
+
106
+
107
def load_selected_model(checkpoint_label):
    """Load the checkpoint selected in the dropdown and build its runtime bundle.

    Returns a 5-tuple wired directly into Gradio outputs:
    (bundle_state, status_markdown, model_info_json, num_steps_slider_value,
    suggested_analysis_output_dir). Raises gr.Error on any missing-checkpoint
    condition so the UI shows a toast instead of a traceback.
    """
    mapping = checkpoint_map()
    if not mapping:
        raise gr.Error("No checkpoints found. Add models under ablation_results/ or results*/.")
    # Empty label (e.g. first auto-load) falls back to the preferred default.
    if not checkpoint_label:
        checkpoint_label = default_checkpoint_label()
    if checkpoint_label not in mapping:
        raise gr.Error("Selected checkpoint not found. Click refresh.")

    ckpt_path = mapping[checkpoint_label]["path"]
    cfg, device, experiment = build_runtime_cfg(ckpt_path)
    # load_model may adjust cfg from checkpoint metadata; keep the returned cfg.
    model, cfg = load_model(ckpt_path, cfg, device)
    src_tok, tgt_tok = _build_tokenizers(cfg)

    # Everything inference needs later, held in a gr.State.
    bundle = {
        "ckpt_path": ckpt_path,
        "experiment": experiment,
        "device": str(device),
        "cfg": cfg,
        "model": model,
        "src_tok": src_tok,
        "tgt_tok": tgt_tok,
    }
    # Human-readable summary shown in the "Loaded Model Details" JSON panel.
    model_info = {
        "checkpoint": ckpt_path,
        "experiment": experiment,
        "model_type": cfg["model_type"],
        "include_negatives": cfg["data"]["include_negative_examples"],
        "device": str(device),
        "max_seq_len": cfg["model"]["max_seq_len"],
        "diffusion_steps": cfg["model"]["diffusion_steps"],
        "inference_steps": cfg["inference"]["num_steps"],
        "d_model": cfg["model"]["d_model"],
        "n_layers": cfg["model"]["n_layers"],
        "n_heads": cfg["model"]["n_heads"],
    }
    status = f"Loaded `{experiment}` on `{device}` (`{cfg['model_type']}`)"
    # Per-experiment analysis dir suggestion, pushed into the Task Runner box.
    suggested_out = os.path.join("analysis", "outputs_ui", experiment)
    return bundle, status, model_info, cfg["inference"]["num_steps"], suggested_out
146
+
147
+
148
def apply_preset(preset_name):
    """Translate a sampling preset name into slider values.

    Returns (temperature, top_k, repetition_penalty, diversity_penalty).
    """
    if preset_name == "Literal":
        return 0.60, 20, 1.25, 0.0
    if preset_name == "Creative":
        return 0.90, 80, 1.05, 0.2
    # "Manual", "Balanced" and any unknown name share the balanced defaults.
    return 0.70, 40, 1.20, 0.0
156
+
157
+
158
def clean_generated_text(text: str, max_consecutive: int = 2) -> str:
    """Normalize whitespace and cap immediate token repeats.

    At most *max_consecutive* identical consecutive tokens survive; danda
    punctuation (। and ॥) is re-attached to the preceding word.
    """
    normalized = " ".join(text.split())
    if not normalized:
        return normalized
    kept = []
    last_token = None
    streak = 0
    for token in normalized.split():
        streak = streak + 1 if token == last_token else 1
        last_token = token
        if streak <= max_consecutive:
            kept.append(token)
    rejoined = " ".join(kept).replace(" ।", "।").replace(" ॥", "॥")
    return " ".join(rejoined.split())
176
+
177
+
178
def save_generation(experiment, record):
    """Append *record* to today's per-experiment JSON log and return its path.

    The log file under RESULTS_DIR is read-modify-written as a JSON array,
    one file per experiment per day.
    """
    day_stamp = datetime.now().strftime("%Y%m%d")
    log_path = os.path.join(RESULTS_DIR, f"{experiment}_ui_{day_stamp}.json")

    records = []
    if os.path.exists(log_path):
        with open(log_path, "r", encoding="utf-8") as handle:
            records = json.load(handle)

    records.append(record)
    with open(log_path, "w", encoding="utf-8") as handle:
        json.dump(records, handle, ensure_ascii=False, indent=2)
    return log_path
189
+
190
+
191
def generate_from_ui(
    model_bundle,
    input_text,
    temperature,
    top_k,
    repetition_penalty,
    diversity_penalty,
    num_steps,
    clean_output,
):
    """Run one diffusion inference for the playground tab.

    Args come straight from Gradio widgets (bundle state, textbox, sliders,
    checkbox). Returns (output_text, status_markdown, record_dict); the record
    is also appended to the per-experiment JSON log via save_generation().
    Raises gr.Error when no model is loaded or the input is blank.
    """
    if not model_bundle:
        raise gr.Error("Load a model first.")
    if not input_text.strip():
        raise gr.Error("Enter input text first.")

    # Deep-copy so slider overrides never mutate the loaded bundle's config.
    cfg = copy.deepcopy(model_bundle["cfg"])
    cfg["inference"]["temperature"] = float(temperature)
    cfg["inference"]["top_k"] = int(top_k)
    cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
    cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
    cfg["inference"]["num_steps"] = int(num_steps)

    src_tok = model_bundle["src_tok"]
    tgt_tok = model_bundle["tgt_tok"]
    device = torch.device(model_bundle["device"])

    # Truncate to max_seq_len so long inputs cannot overflow the model.
    input_ids = torch.tensor(
        [src_tok.encode(input_text.strip())[:cfg["model"]["max_seq_len"]]],
        dtype=torch.long,
        device=device,
    )
    out = run_inference(model_bundle["model"], input_ids, cfg)

    # Use the exact inference decode/cleanup logic for parity with inference.py
    raw_output_text = _decode_clean(tgt_tok, out[0].tolist())
    if clean_output:
        output_text = _decode_with_cleanup(
            tgt_tok, out[0].tolist(), input_text.strip(), cfg["inference"]
        )
    else:
        output_text = raw_output_text
    if not output_text:
        output_text = "(empty output)"

    # Full provenance record: settings + raw and cleaned outputs.
    record = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "experiment": model_bundle["experiment"],
        "checkpoint": model_bundle["ckpt_path"],
        "input_text": input_text,
        "raw_output_text": raw_output_text,
        "output_text": output_text,
        "temperature": float(temperature),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "diversity_penalty": float(diversity_penalty),
        "num_steps": int(num_steps),
        "clean_output": bool(clean_output),
    }
    log_path = save_generation(model_bundle["experiment"], record)
    status = f"Inference done. Saved: `{log_path}`"
    return output_text, status, record
252
+
253
+
254
def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati rakṣitaḥ", phase="analyze"):
    """Invoke analysis/run_analysis.py for one task as a subprocess.

    Returns (exit_code, combined_log) where the log echoes the command line
    followed by captured stdout and stderr.
    """
    os.makedirs(output_dir, exist_ok=True)
    task_id = str(task)
    cmd = [
        sys.executable,
        "analysis/run_analysis.py",
        "--task", task_id,
        "--checkpoint", ckpt_path,
        "--output_dir", output_dir,
    ]
    # Task 2 (and the combined run) needs the probe sentence.
    if task_id in ("2", "all"):
        cmd += ["--input", input_text]
    # Task 4 has an explicit phase switch.
    if task_id == "4":
        cmd += ["--phase", phase]

    # Point HF caches at /tmp unless the host already configured them.
    env = os.environ.copy()
    for var, default in (
        ("HF_HOME", "/tmp/hf_home"),
        ("HF_DATASETS_CACHE", "/tmp/hf_datasets"),
        ("HF_HUB_CACHE", "/tmp/hf_hub"),
    ):
        env.setdefault(var, default)

    proc = subprocess.run(cmd, capture_output=True, text=True, env=env)
    log = f"$ {' '.join(cmd)}\n\n{proc.stdout}\n{proc.stderr}"
    return proc.returncode, log
279
+
280
+
281
def run_single_task(model_bundle, task, output_dir, input_text, task4_phase):
    """Execute one analysis task against the loaded checkpoint.

    Returns (status_markdown, execution_log) for the Task Runner tab.
    """
    if not model_bundle:
        raise gr.Error("Load a model first.")
    exit_code, log = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)
    outcome = "completed" if exit_code == 0 else "failed"
    return f"Task {task} {outcome} (exit={exit_code}).", log
287
+
288
+
289
def run_all_tasks(model_bundle, output_dir, input_text, task4_phase):
    """Run analysis tasks 1-5 sequentially and concatenate their logs.

    Keeps going after a failure so one broken task never hides the rest;
    returns (status_markdown, combined_log).
    """
    if not model_bundle:
        raise gr.Error("Load a model first.")
    logs = []
    failures = 0
    for task in ("1", "2", "3", "4", "5"):
        code, log = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)
        logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
        if code != 0:
            failures += 1
    status = f"Run-all finished with {failures} failed task(s)." if failures else "All 5 tasks completed."
    return status, "".join(logs)
301
+
302
+
303
+ def _read_text(path):
304
+ if not os.path.exists(path):
305
+ return "Not found."
306
+ with open(path, "r", encoding="utf-8", errors="ignore") as f:
307
+ return f.read()
308
+
309
+
310
+ def _img_or_none(path):
311
+ return path if os.path.exists(path) else None
312
+
313
+
314
def refresh_task_outputs(output_dir):
    """Collect the reports and plots the analysis tasks wrote to *output_dir*.

    Returns the 8-tuple wired to the Task Outputs Viewer widgets:
    (task1_text, task2_text, task2_drift_img, task2_attn_img,
     task3_text, task3_img, task5_text, task4_img).
    """
    def report(filename):
        return _read_text(os.path.join(output_dir, filename))

    def plot(filename):
        return _img_or_none(os.path.join(output_dir, filename))

    # Task 4's plot filename varies between runs; try the newer name first.
    task4_plot = plot("task4_ablation_3d.png") or plot("task4_3d.png")
    return (
        report("task1_kv_cache.txt"),
        report("task2_report.txt"),
        plot("task2_semantic_drift.png"),
        plot("task2_attn_t0.png"),
        report("task3_report.txt"),
        plot("task3_concept_space.png"),
        report("task5_report.txt"),
        task4_plot,
    )
327
+
328
+
329
# Light blue/green theme for the demo; injected via the css= argument at launch.
CUSTOM_CSS = """
:root {
  --bg1: #f5fbff;
  --bg2: #f2f7ef;
  --card: #ffffff;
  --line: #d9e6f2;
  --ink: #163048;
}
.gradio-container {
  background: linear-gradient(130deg, var(--bg1), var(--bg2));
  color: var(--ink);
}
#hero {
  background: radial-gradient(110% 130% at 0% 0%, #d7ebff 0%, #ecf6ff 55%, #f8fbff 100%);
  border: 1px solid #cfe0f1;
  border-radius: 16px;
  padding: 18px 20px;
}
.panel {
  background: var(--card);
  border: 1px solid var(--line);
  border-radius: 14px;
}
"""
353
 
354
 
355
# UI definition. Widget creation order matters: the event wiring at the bottom
# references widgets from both tabs (e.g. load_btn.click writes into num_steps,
# which lives in the Inference tab).
with gr.Blocks(title="Sanskrit Diffusion Client Demo") as demo:
    # Holds the loaded-model bundle dict produced by load_selected_model().
    model_state = gr.State(None)

    gr.Markdown(
        """
<div id="hero">
<h1 style="margin:0;">Sanskrit Diffusion Client Demo</h1>
<p style="margin:.5rem 0 0 0;">
Select any trained model, run all 5 analysis tasks or individual tasks, then test inference with user-controlled parameters.
</p>
</div>
        """
    )

    # --- Model selection row ---
    with gr.Row():
        with gr.Column(scale=2, elem_classes=["panel"]):
            checkpoint_dropdown = gr.Dropdown(
                label="Model Checkpoint",
                choices=list(checkpoint_map().keys()),
                value=default_checkpoint_label(),
                interactive=True,
            )
        with gr.Column(scale=1, elem_classes=["panel"]):
            refresh_btn = gr.Button("Refresh Models")
            load_btn = gr.Button("Load Selected Model", variant="primary")

    init_msg = "Select a model and load." if checkpoint_map() else "No checkpoints found in ablation_results/ or results*/."
    load_status = gr.Markdown(init_msg)
    model_info = gr.JSON(label="Loaded Model Details")

    with gr.Tabs():
        # --- Tab 1: run analysis tasks and view their outputs ---
        with gr.Tab("1) Task Runner"):
            with gr.Row():
                with gr.Column(scale=2):
                    analysis_output_dir = gr.Textbox(
                        label="Analysis Output Directory",
                        value=DEFAULT_ANALYSIS_OUT,
                    )
                    analysis_input = gr.Textbox(
                        label="Task 2 Input Text",
                        value="dharmo rakṣati rakṣitaḥ",
                        lines=2,
                    )
                with gr.Column(scale=1):
                    task4_phase = gr.Dropdown(
                        choices=["analyze", "generate_configs"],
                        value="analyze",
                        label="Task 4 Phase",
                    )
                    run_all_btn = gr.Button("Run All 5 Tasks", variant="primary")

            with gr.Row():
                task_choice = gr.Dropdown(
                    choices=["1", "2", "3", "4", "5"],
                    value="1",
                    label="Single Task",
                )
                run_single_btn = gr.Button("Run Selected Task")
                refresh_outputs_btn = gr.Button("Refresh Output Viewer")

            task_run_status = gr.Markdown("")
            task_run_log = gr.Textbox(label="Task Execution Log", lines=18, interactive=False)

            with gr.Accordion("Task Outputs Viewer", open=True):
                task1_box = gr.Textbox(label="Task 1 Report", lines=10, interactive=False)
                task2_box = gr.Textbox(label="Task 2 Report", lines=10, interactive=False)
                with gr.Row():
                    task2_drift_img = gr.Image(label="Task2 Drift", type="filepath")
                    task2_attn_img = gr.Image(label="Task2 Attention", type="filepath")
                task3_box = gr.Textbox(label="Task 3 Report", lines=10, interactive=False)
                task3_img = gr.Image(label="Task3 Concept Space", type="filepath")
                task5_box = gr.Textbox(label="Task 5 Report", lines=10, interactive=False)
                task4_img = gr.Image(label="Task4 3D Ablation Plot", type="filepath")

        # --- Tab 2: interactive inference ---
        with gr.Tab("2) Inference Playground"):
            with gr.Row():
                with gr.Column(scale=2):
                    input_text = gr.Textbox(
                        label="Input (Roman / IAST)",
                        lines=4,
                        value="dharmo rakṣati rakṣitaḥ",
                    )
                    output_text = gr.Textbox(
                        label="Output (Devanagari)",
                        lines=7,
                        interactive=False,
                    )
                    run_status = gr.Markdown("")
                    run_record = gr.JSON(label="Inference Metadata")
                with gr.Column(scale=1, elem_classes=["panel"]):
                    preset = gr.Radio(["Manual", "Literal", "Balanced", "Creative"], value="Balanced", label="Preset")
                    temperature = gr.Slider(0.4, 1.2, value=0.70, step=0.05, label="Temperature")
                    top_k = gr.Slider(5, 100, value=40, step=1, label="Top-K")
                    repetition_penalty = gr.Slider(1.0, 3.0, value=1.20, step=0.05, label="Repetition Penalty")
                    diversity_penalty = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Diversity Penalty")
                    num_steps = gr.Slider(1, 128, value=64, step=1, label="Inference Steps")
                    clean_output = gr.Checkbox(value=True, label="Clean Output")
                    generate_btn = gr.Button("Generate", variant="primary")

            gr.Examples(
                examples=[
                    ["dharmo rakṣati rakṣitaḥ"],
                    ["satyameva jayate"],
                    ["yadā mano nivarteta viṣayebhyaḥ svabhāvataḥ"],
                ],
                inputs=[input_text],
            )

    # Rebuild the dropdown choices from disk without reloading any model.
    def refresh_checkpoints():
        choices = list(checkpoint_map().keys())
        value = default_checkpoint_label() if choices else None
        msg = f"Found {len(choices)} checkpoint(s)." if choices else "No checkpoints found."
        return gr.Dropdown(choices=choices, value=value), msg

    # Startup hook: load the default checkpoint when one exists.
    # NOTE(review): when checkpoints exist but loading fails, the gr.Error
    # raised inside load_selected_model surfaces at page load — confirm this
    # is the intended UX.
    def auto_load_default():
        choices = list(checkpoint_map().keys())
        if not choices:
            return None, "No checkpoints found.", {}, 64, DEFAULT_ANALYSIS_OUT
        return load_selected_model(default_checkpoint_label())

    # --- Event wiring ---
    refresh_btn.click(fn=refresh_checkpoints, outputs=[checkpoint_dropdown, load_status])
    load_btn.click(
        fn=load_selected_model,
        inputs=[checkpoint_dropdown],
        outputs=[model_state, load_status, model_info, num_steps, analysis_output_dir],
    )

    preset.change(
        fn=apply_preset,
        inputs=[preset],
        outputs=[temperature, top_k, repetition_penalty, diversity_penalty],
    )

    # Generate on button click and on Enter in the input box (same handler).
    generate_btn.click(
        fn=generate_from_ui,
        inputs=[
            model_state,
            input_text,
            temperature,
            top_k,
            repetition_penalty,
            diversity_penalty,
            num_steps,
            clean_output,
        ],
        outputs=[output_text, run_status, run_record],
    )
    input_text.submit(
        fn=generate_from_ui,
        inputs=[
            model_state,
            input_text,
            temperature,
            top_k,
            repetition_penalty,
            diversity_penalty,
            num_steps,
            clean_output,
        ],
        outputs=[output_text, run_status, run_record],
    )

    run_single_btn.click(
        fn=run_single_task,
        inputs=[model_state, task_choice, analysis_output_dir, analysis_input, task4_phase],
        outputs=[task_run_status, task_run_log],
    )
    run_all_btn.click(
        fn=run_all_tasks,
        inputs=[model_state, analysis_output_dir, analysis_input, task4_phase],
        outputs=[task_run_status, task_run_log],
    )
    refresh_outputs_btn.click(
        fn=refresh_task_outputs,
        inputs=[analysis_output_dir],
        outputs=[
            task1_box,
            task2_box,
            task2_drift_img,
            task2_attn_img,
            task3_box,
            task3_img,
            task5_box,
            task4_img,
        ],
    )
    # Page-load hooks: auto-load the default model, then populate the viewer.
    demo.load(
        fn=auto_load_default,
        outputs=[model_state, load_status, model_info, num_steps, analysis_output_dir],
    )
    demo.load(
        fn=refresh_task_outputs,
        inputs=[analysis_output_dir],
        outputs=[
            task1_box,
            task2_box,
            task2_drift_img,
            task2_attn_img,
            task3_box,
            task3_img,
            task5_box,
            task4_img,
        ],
    )
559
 
560
 
561
  if __name__ == "__main__":
562
+ port = int(os.environ["GRADIO_SERVER_PORT"]) if "GRADIO_SERVER_PORT" in os.environ else None
563
+ demo.launch(server_name="127.0.0.1", server_port=port, share=False, css=CUSTOM_CSS)