Toadoum commited on
Commit
08d20bc
·
verified ·
1 Parent(s): 707226a

Upload 5 files

Browse files
Files changed (4) hide show
  1. app.py +509 -0
  2. dialogue.py +237 -0
  3. nlu.py +310 -0
  4. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PlotWeaver Voice Agent — HuggingFace Space
3
+ ============================================
4
+ Gradio app demonstrating a Hausa-first conversational AI for
5
+ African banks, telecoms, and delivery services.
6
+
7
+ Pipeline: ASR (Whisper-small) → NLU (rule-based) → Dialogue FSM →
8
+ TTS (facebook/mms-tts-hau).
9
+
10
+ Runs on CPU. First turn triggers model download (~500MB), subsequent turns
11
+ are ~2-4s end-to-end.
12
+ """
13
+ from __future__ import annotations
14
+ import time
15
+ import uuid
16
+ import html as html_lib
17
+ from typing import Optional
18
+
19
+ import gradio as gr
20
+ import numpy as np
21
+ import torch
22
+ from transformers import (
23
+ VitsModel, AutoTokenizer,
24
+ WhisperProcessor, WhisperForConditionalGeneration,
25
+ )
26
+
27
+ from dialogue import (
28
+ DialogueState, SCENARIOS,
29
+ get_prompt, get_expected_slot, transition,
30
+ )
31
+ from nlu import parse as nlu_parse
32
+
33
+
34
+ # ---------------------------------------------------------------------------
35
+ # Model loading (lazy, cached)
36
+ # ---------------------------------------------------------------------------
37
+ _asr_model = None
38
+ _asr_processor = None
39
+ _tts_model = None
40
+ _tts_tokenizer = None
41
+
42
+
43
def load_asr():
    """Lazily load and cache the Whisper-small ASR model and processor.

    Returns the ``(model, processor)`` pair. The first call downloads the
    weights; later calls return the module-level cache untouched.
    """
    global _asr_model, _asr_processor
    if _asr_model is not None:
        return _asr_model, _asr_processor
    print("Loading Whisper-small…")
    _asr_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    _asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    _asr_model.eval()  # inference mode (disables dropout etc.)
    print("Whisper-small ready.")
    return _asr_model, _asr_processor
52
+
53
+
54
def load_tts():
    """Lazily load and cache the MMS-TTS Hausa model and tokenizer.

    Returns the ``(model, tokenizer)`` pair; downloads happen only on the
    first call, the module-level cache is reused afterwards.
    """
    global _tts_model, _tts_tokenizer
    if _tts_model is not None:
        return _tts_model, _tts_tokenizer
    print("Loading MMS-TTS Hausa…")
    _tts_model = VitsModel.from_pretrained("facebook/mms-tts-hau")
    _tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-hau")
    _tts_model.eval()  # inference mode
    print("MMS-TTS Hausa ready.")
    return _tts_model, _tts_tokenizer
63
+
64
+
65
def transcribe_hausa(audio_tuple) -> str:
    """Transcribe a Gradio audio tuple to Hausa text with Whisper-small.

    Parameters
    ----------
    audio_tuple : tuple | None
        ``(sample_rate, np.ndarray)`` as delivered by ``gr.Audio(type="numpy")``.
        The array may be integer PCM (e.g. int16) or float, mono or multi-channel.

    Returns
    -------
    str
        The transcription, or ``""`` when there is no usable audio.
    """
    if audio_tuple is None:
        return ""
    sample_rate, audio_array = audio_tuple
    if audio_array is None or len(audio_array) == 0:
        return ""
    # Convert to float32 mono. Integer PCM is normalized to [-1, 1].
    # BUGFIX: float input that isn't float32 (e.g. float64 from some decoders)
    # must only be cast — np.iinfo() raises ValueError on float dtypes.
    if np.issubdtype(audio_array.dtype, np.integer):
        audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
    elif audio_array.dtype != np.float32:
        audio_array = audio_array.astype(np.float32)
    if audio_array.ndim > 1:
        audio_array = audio_array.mean(axis=1)
    # Cap at 30s — Whisper-small is trained on 30s chunks; longer audio
    # would need windowing which slows the demo
    max_samples = sample_rate * 30
    if len(audio_array) > max_samples:
        audio_array = audio_array[:max_samples]
    # Resample to 16 kHz (Whisper's expected input rate)
    if sample_rate != 16000:
        import scipy.signal
        num_samples = int(len(audio_array) * 16000 / sample_rate)
        audio_array = scipy.signal.resample(audio_array, num_samples).astype(np.float32)

    model, processor = load_asr()
    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
    # Pin decoding to Hausa transcription so Whisper neither auto-detects
    # the language nor translates to English.
    forced_ids = processor.get_decoder_prompt_ids(language="hausa", task="transcribe")
    with torch.no_grad():
        ids = model.generate(inputs.input_features, forced_decoder_ids=forced_ids, max_new_tokens=128)
    text = processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
    return text
95
+
96
+
97
def synthesize_hausa(text: str) -> Optional[tuple]:
    """Synthesize Hausa speech for *text*.

    Returns ``(sample_rate, np.float32 waveform)`` suitable for a Gradio
    audio component, or ``None`` when the text is blank.
    """
    if not text.strip():
        return None
    model, tokenizer = load_tts()
    token_batch = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        waveform = model(**token_batch).waveform
    samples = waveform.squeeze().cpu().numpy().astype(np.float32)
    return (model.config.sampling_rate, samples)
107
+
108
+
109
+ # ---------------------------------------------------------------------------
110
+ # Core turn handler
111
+ # ---------------------------------------------------------------------------
112
def run_turn(user_text: str, session: dict, trace: list, asr_ms: int = 0) -> tuple:
    """
    Executes one turn. Returns (bot_prompt_dict, updated_session, trace, tts_audio).
    `session` is a serialized dict stored in gr.State.

    Each pipeline stage (NLU → dialogue FSM → response lookup → TTS) is timed
    and appended to the returned trace list as {"stage", "ms", "detail"} rows
    for the UI trace panel; `asr_ms`, when non-zero, prepends the
    caller-measured ASR latency.

    NOTE(review): the `trace` parameter is accepted but never read — all
    callers pass []; confirm before removing it from the signature.
    """
    # Rehydrate dialogue state; fall back to a fresh bank-vertical session.
    state = DialogueState.from_dict(session) if session else None
    if state is None:
        state = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical="bank")

    turn_trace = []
    if asr_ms:
        turn_trace.append({"stage": "asr (whisper-small)", "ms": asr_ms,
                           "detail": f'→ "{user_text}"'})

    # NLU, conditioned on the slot type the current FSM state expects.
    t0 = time.time()
    expected = get_expected_slot(state.vertical, state.current_state)
    intent, entities, nlu_source = nlu_parse(user_text, expected)
    # Label which NLU tier produced the answer (rule fast-path vs LLM fallback).
    nlu_stage_label = {
        "rule": "nlu (rule-based)",
        "llm": "nlu (qwen2.5-1.5b)",
        "rule_fallback": "nlu (rule + llm fallback)",
    }.get(nlu_source, "nlu")
    turn_trace.append({
        "stage": nlu_stage_label,
        # max(1, …): floor sub-millisecond stages at 1ms so the row is visible.
        "ms": max(1, int((time.time() - t0) * 1000)),
        "detail": f"intent={intent} entities={entities}",
    })

    # FSM transition (mutates state in place and returns it).
    t1 = time.time()
    prev_state = state.current_state
    state = transition(state, intent, entities)
    turn_trace.append({
        "stage": "dialogue_manager",
        "ms": max(1, int((time.time() - t1) * 1000)),
        "detail": f"{prev_state} → {state.current_state}",
    })

    # Templated response lookup for the new state.
    t2 = time.time()
    prompt = get_prompt(state.vertical, state.current_state)
    turn_trace.append({"stage": "response_gen", "ms": max(1, int((time.time() - t2) * 1000))})

    # Hausa TTS of the bot prompt (dominant latency of the turn).
    t3 = time.time()
    audio = synthesize_hausa(prompt["ha"])
    turn_trace.append({"stage": "tts (mms-tts-hau)", "ms": int((time.time() - t3) * 1000)})

    # Record both sides of the exchange in the session transcript.
    state.history.append({"role": "user", "text": user_text})
    state.history.append({"role": "bot", "text_ha": prompt["ha"], "text_en": prompt["en"]})

    return prompt, state.to_dict(), turn_trace, audio
161
+
162
+
163
+ # ---------------------------------------------------------------------------
164
+ # WhatsApp-style HTML renderer
165
+ # ---------------------------------------------------------------------------
166
def render_whatsapp(session: dict, pending_user: Optional[str] = None,
                    pending_is_voice: bool = False) -> str:
    """Render the whole conversation as a WhatsApp-style phone mockup (HTML).

    `pending_user`, when given, appends one extra user bubble that is not yet
    part of the stored history; `pending_is_voice` draws it as a voice note.
    An escalation banner is shown once the FSM has flagged the session.
    """
    vertical = session.get("vertical", "bank") if session else "bank"
    name = SCENARIOS[vertical]["name"]
    avatar = {"bank": "PB", "telecom": "PT", "ecommerce": "PD"}[vertical]
    escalated = session.get("escalate_to_human", False) if session else False

    bubbles = []
    history = session.get("history", []) if session else []
    for msg in history:
        if msg["role"] == "user":
            is_voice = msg.get("is_voice", False)
            bubbles.append(_user_bubble(msg["text"], is_voice))
        else:
            bubbles.append(_bot_bubble(msg.get("text_ha", ""), msg.get("text_en", "")))
    if pending_user:
        bubbles.append(_user_bubble(pending_user, pending_is_voice))

    banner = ('<div class="pw-esc-banner">Session escalated to human agent</div>'
              if escalated else "")

    # NOTE: CSS braces below are doubled ({{ }}) because this is an f-string.
    return f"""
    <div class="pw-phone">
      <div class="pw-ph-header">
        <div class="pw-ph-avatar">{avatar}</div>
        <div>
          <div class="pw-ph-name">{html_lib.escape(name)}</div>
          <div class="pw-ph-status">online • voice agent</div>
        </div>
      </div>
      <div class="pw-ph-messages">
        {banner}
        {"".join(bubbles) if bubbles else '<div style="text-align:center; color:#667781; font-size:12px; padding:40px 0;">Waiting for first message…</div>'}
      </div>
    </div>
    <style>
    .pw-phone {{ max-width: 440px; margin: 0 auto; background: #ECE5DD; border-radius: 14px; overflow: hidden; border: 1px solid #ccc; display: flex; flex-direction: column; min-height: 520px; font-family: -apple-system, "Segoe UI", Roboto, sans-serif; }}
    .pw-ph-header {{ background: #075E54; color: #fff; padding: 10px 14px; display: flex; align-items: center; gap: 10px; }}
    .pw-ph-avatar {{ width: 36px; height: 36px; border-radius: 50%; background: #128C7E; display: flex; align-items: center; justify-content: center; font-weight: 500; font-size: 13px; color: #fff; }}
    .pw-ph-name {{ font-size: 14px; font-weight: 500; line-height: 1.2; }}
    .pw-ph-status {{ font-size: 11px; color: #D4EDE8; }}
    .pw-ph-messages {{ flex: 1; padding: 14px 10px; background: #ECE5DD; background-image: radial-gradient(#D8CFC2 1px, transparent 1px); background-size: 18px 18px; max-height: 460px; overflow-y: auto; min-height: 400px; }}
    .pw-b {{ max-width: 80%; padding: 7px 10px 5px; border-radius: 8px; margin-bottom: 6px; font-size: 13.5px; line-height: 1.4; color: #1f2d1f; word-wrap: break-word; }}
    .pw-b.user {{ background: #DCF8C6; margin-left: auto; border-bottom-right-radius: 2px; }}
    .pw-b.bot {{ background: #fff; margin-right: auto; border-bottom-left-radius: 2px; }}
    .pw-b-meta {{ font-size: 10px; color: #667781; margin-top: 3px; text-align: right; }}
    .pw-b-trans {{ font-size: 11px; color: #667781; font-style: italic; margin-top: 3px; border-top: 1px solid #E5E5E5; padding-top: 3px; }}
    .pw-voice-row {{ display: flex; align-items: center; gap: 8px; }}
    .pw-voice-icon {{ width: 22px; height: 22px; border-radius: 50%; background: #128C7E; color: #fff; font-size: 10px; display: flex; align-items: center; justify-content: center; }}
    .pw-voice-bars {{ flex: 1; height: 14px; display: flex; align-items: center; gap: 2px; }}
    .pw-voice-bars span {{ flex: 1; background: #8D9A9F; border-radius: 1px; }}
    .pw-esc-banner {{ background: #FAEEDA; color: #854F0B; font-size: 12px; padding: 8px 12px; border-radius: 8px; margin-bottom: 10px; border: 1px solid #EF9F27; text-align: center; }}
    </style>
    """
220
+
221
+
222
+ def _now() -> str:
223
+ return time.strftime("%H:%M")
224
+
225
+
226
def _user_bubble(text: str, is_voice: bool) -> str:
    """Render one outgoing (user) chat bubble as HTML.

    Voice messages get a decorative WhatsApp-style waveform with the
    transcription quoted underneath; text messages get a plain bubble.
    User-provided text is HTML-escaped in both cases.
    """
    text_safe = html_lib.escape(text)
    if is_voice:
        # 20 bars, heights 4–12px following |sin(0.7·i)| — deterministic,
        # purely decorative (not derived from the actual audio).
        bars = "".join(
            f'<span style="height:{4 + int(8 * abs(np.sin(i * 0.7)))}px;"></span>'
            for i in range(20)
        )
        return f'''<div class="pw-b user">
    <div class="pw-voice-row">
        <div class="pw-voice-icon">▶</div>
        <div class="pw-voice-bars">{bars}</div>
    </div>
    <div style="font-size:12px; color:#667781; margin-top:3px;">"{text_safe}"</div>
    <div class="pw-b-meta">{_now()} ✓✓</div>
</div>'''
    return f'<div class="pw-b user">{text_safe}<div class="pw-b-meta">{_now()} ✓✓</div></div>'
242
+
243
+
244
def _bot_bubble(text_ha: str, text_en: str) -> str:
    """Render one incoming (bot) bubble: Hausa text with an English gloss
    below a divider. Both strings are HTML-escaped."""
    ha_safe = html_lib.escape(text_ha)
    en_safe = html_lib.escape(text_en)
    return f'''<div class="pw-b bot">
    <div>{ha_safe}</div>
    <div class="pw-b-trans">{en_safe}</div>
    <div class="pw-b-meta">{_now()} ✓✓</div>
</div>'''
252
+
253
+
254
def render_trace(trace: list) -> str:
    """Render the per-stage latency rows of the last turn as inline-styled HTML.

    Each trace entry is {"stage": str, "ms": int, "detail": optional str};
    details are shown in a smaller monospace line under their stage row.
    """
    if not trace:
        return '<div style="color:#888; font-size:13px;">Send a message to see the pipeline trace.</div>'
    pieces = []
    for entry in trace:
        stage_safe = html_lib.escape(entry["stage"])
        pieces.append(
            '<div style="display:flex; justify-content:space-between; padding:5px 0; border-bottom:1px solid #eee;">'
            f'<span style="color:#5f5e5a;">{stage_safe}</span>'
            f'<span style="color:#0C447C; font-weight:500;">{entry["ms"]}ms</span></div>'
        )
        detail = entry.get("detail")
        if detail:
            pieces.append(
                f'<div style="font-size:11px; color:#888; padding:0 0 5px; font-family:monospace;">{html_lib.escape(str(detail))}</div>'
            )
    return f'<div style="font-family:monospace; font-size:12px;">{"".join(pieces)}</div>'
264
+
265
+
266
def render_metrics(session: dict) -> str:
    """Summarize session id, turn count, FSM state and filled slots as a
    two-column HTML grid; empty/None sessions render nothing."""
    if not session:
        return ""
    sid = session.get("session_id", "—")
    turn = session.get("turn_count", 0)
    state = session.get("current_state", "greeting")
    filled = session.get("slots", {})
    # Em-dash placeholder when no slots have been captured yet.
    slots_html = ", ".join(f"<code>{key}={val}</code>" for key, val in filled.items()) if filled else "—"
    return f'''
<div style="display:grid; grid-template-columns:1fr 1fr; gap:8px; font-size:13px;">
    <div><div style="color:#888; font-size:11px; text-transform:uppercase;">Session</div><div style="font-family:monospace;">{sid}</div></div>
    <div><div style="color:#888; font-size:11px; text-transform:uppercase;">Turn</div><div style="font-weight:500;">{turn}</div></div>
    <div><div style="color:#888; font-size:11px; text-transform:uppercase;">State</div><div style="font-family:monospace;">{state}</div></div>
    <div><div style="color:#888; font-size:11px; text-transform:uppercase;">Slots</div><div>{slots_html}</div></div>
</div>'''
281
+
282
+
283
+ # ---------------------------------------------------------------------------
284
+ # Gradio event handlers
285
+ # ---------------------------------------------------------------------------
286
def on_vertical_change(vertical: str, synth_greeting: bool = False):
    """Start a fresh session for *vertical* and render the initial UI state.

    TTS of the greeting is skipped unless ``synth_greeting`` is True, so the
    first page load does not pay the MMS-TTS cold-start cost.
    Returns (session, whatsapp_html, trace_html, metrics_html, audio).
    """
    fresh = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical=vertical)
    greeting = get_prompt(vertical, "greeting")
    fresh.history.append({"role": "bot", "text_ha": greeting["ha"], "text_en": greeting["en"]})
    session = fresh.to_dict()

    greeting_audio = None
    if synth_greeting:
        try:
            greeting_audio = synthesize_hausa(greeting["ha"])
        except Exception as e:
            # Best-effort: a TTS failure must not block the session reset.
            print(f"TTS failed on greeting: {e}")

    return (
        session,
        render_whatsapp(session),
        render_trace([]),
        render_metrics(session),
        greeting_audio,
    )
306
+
307
+
308
def on_text_submit(text: str, session: dict):
    """Handle a typed message: run one dialogue turn and refresh all panels.

    Returns the 5 standard outputs plus "" to clear the text input.
    Blank/whitespace input just re-renders the current state.
    """
    if not text or not text.strip():
        return session, render_whatsapp(session), render_trace([]), render_metrics(session), None, ""
    _prompt, new_session, trace, audio = run_turn(text, session, [], asr_ms=0)
    return (
        new_session,
        render_whatsapp(new_session),
        render_trace(trace),
        render_metrics(new_session),
        audio,
        "",  # clear input
    )
320
+
321
+
322
def on_audio_submit(audio_data, session: dict):
    """Handle a recorded/uploaded clip: ASR → dialogue turn → refresh panels.

    ASR failures and silent clips short-circuit with an explanatory trace row
    instead of advancing the dialogue. Returns the 5 standard outputs.
    """
    def _unchanged(trace_rows):
        # Re-render the current session without running a turn.
        return session, render_whatsapp(session), render_trace(trace_rows), render_metrics(session), None

    if audio_data is None:
        return _unchanged([])

    started = time.time()
    try:
        text = transcribe_hausa(audio_data)
    except Exception as exc:
        print(f"ASR failed: {exc}")
        return _unchanged([{"stage": "asr error", "ms": 0, "detail": str(exc)}])
    asr_ms = int((time.time() - started) * 1000)
    if not text:
        return _unchanged([{"stage": "asr", "ms": asr_ms, "detail": "(no speech detected)"}])

    _prompt, new_session, trace, audio = run_turn(text, session, [], asr_ms=asr_ms)
    # run_turn stored the user message as plain text; flag the most recent
    # user entry as voice so the renderer draws the waveform bubble.
    for entry in reversed(new_session.get("history", [])):
        if entry["role"] == "user":
            entry["is_voice"] = True
            break

    return (
        new_session,
        render_whatsapp(new_session),
        render_trace(trace),
        render_metrics(new_session),
        audio,
    )
349
+
350
+
351
def on_reset(session: dict):
    """Restart the conversation, keeping the currently selected vertical."""
    current_vertical = (session or {}).get("vertical", "bank")
    return on_vertical_change(current_vertical)
354
+
355
+
356
def on_escalate(session: dict):
    """Force human handover by submitting the Hausa phrase for
    "I want a human agent" as if the user had typed it."""
    return on_text_submit("Ina son wakili mutum", session)
358
+
359
+
360
+ # ---------------------------------------------------------------------------
361
+ # Preset phrases for quick-click demo
362
+ # ---------------------------------------------------------------------------
363
# Quick-click demo phrases per vertical: a few intent triggers followed by
# typical slot values (digits, names, amounts, dates, yes/no answers).
PRESETS = {
    "bank": ["duba ma'auni", "toshe kati", "canjin kuɗi", "1234", "Aisha", "dubu biyar", "i"],
    "telecom": ["saya airtime", "saya bundle", "korafi", "1000", "rana", "Intanet bai aiki"],
    "ecommerce": ["bincika oda", "sake tsara", "mayar da kaya", "10234", "jumma'a", "Ya lalace"],
}
368
+
369
+
370
+ # ---------------------------------------------------------------------------
371
+ # Gradio UI
372
+ # ---------------------------------------------------------------------------
373
+ CUSTOM_CSS = """
374
+ .gradio-container { max-width: 1200px !important; }
375
+ #vertical-selector { background: #fff; border-radius: 10px; padding: 12px; }
376
+ #whatsapp-html { background: #f5f4ef; border-radius: 12px; padding: 20px; }
377
+ #trace-box, #metrics-box { background: #fff; border-radius: 10px; padding: 12px; border: 1px solid #e5e5e5; }
378
+ h1 { font-size: 22px !important; font-weight: 500 !important; }
379
+ .header-sub { color: #5f5e5a; font-size: 14px; margin-top: -8px; margin-bottom: 16px; }
380
+ """
381
+
382
# Gradio UI: left column = controls/metrics/trace, middle = phone mockup,
# text + audio inputs, bot TTS output, and preset quick-phrase buttons.
with gr.Blocks(css=CUSTOM_CSS, title="PlotWeaver Voice Agent") as demo:
    gr.HTML("""
    <h1 style="margin-bottom:4px;">PlotWeaver Voice Agent</h1>
    <p class="header-sub">Hausa-first conversational AI for African banks, telecoms, and delivery services. Real Whisper-small ASR and MMS-TTS Hausa running on CPU.</p>
    """)

    # Serialized DialogueState dict — the single per-user session store.
    session_state = gr.State({})

    with gr.Row():
        # Left column: controls + trace
        with gr.Column(scale=1):
            gr.Markdown("### Select vertical")
            vertical_radio = gr.Radio(
                choices=[("PlotWeaver Bank", "bank"),
                         ("PlotWeaver Telecom", "telecom"),
                         ("PlotWeaver Delivery", "ecommerce")],
                value="bank",
                label="",
                elem_id="vertical-selector",
            )

            with gr.Row():
                reset_btn = gr.Button("Reset session", size="sm")
                escalate_btn = gr.Button("Force escalate", size="sm")

            gr.Markdown("### Session metrics")
            metrics_html = gr.HTML(elem_id="metrics-box")

            gr.Markdown("### Pipeline trace (last turn)")
            trace_html = gr.HTML(elem_id="trace-box")

        # Middle column: WhatsApp mockup
        with gr.Column(scale=2):
            whatsapp_html = gr.HTML(elem_id="whatsapp-html")

            with gr.Row():
                text_input = gr.Textbox(
                    placeholder="Type in Hausa… e.g. 'duba ma'auni'",
                    label="",
                    scale=4,
                    container=False,
                )
                send_btn = gr.Button("Send", scale=1, variant="primary")

            gr.Markdown("**Or speak / upload audio in Hausa:**")
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="Record or upload a Hausa audio file (.wav, .mp3, .ogg)",
                show_download_button=False,
            )
            with gr.Row():
                transcribe_btn = gr.Button("Transcribe & send", variant="secondary", size="sm")
                clear_audio_btn = gr.Button("Clear", size="sm")

            bot_audio = gr.Audio(
                label="Bot response (Hausa TTS)",
                autoplay=True,
                interactive=False,
            )

    # Preset quick-clicks
    # NOTE(review): buttons are built from PRESETS["bank"] only and are not
    # rebuilt when the vertical changes — confirm whether that is intended.
    gr.Markdown("### Quick phrases (Hausa)")
    preset_btns = []
    with gr.Row():
        for p in PRESETS["bank"]:
            preset_btns.append(gr.Button(p, size="sm"))

    # -----------------------------------------------------------------------
    # Event wiring
    # -----------------------------------------------------------------------
    # Common 5-output bundle; text handlers additionally clear text_input.
    outputs = [session_state, whatsapp_html, trace_html, metrics_html, bot_audio]

    # Initial page load: fresh bank session, no greeting TTS (cold-start).
    demo.load(
        fn=lambda: on_vertical_change("bank"),
        outputs=outputs,
    )

    vertical_radio.change(
        fn=on_vertical_change,
        inputs=[vertical_radio],
        outputs=outputs,
    )

    send_btn.click(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=outputs + [text_input],
    )
    text_input.submit(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=outputs + [text_input],
    )

    audio_input.stop_recording(
        fn=on_audio_submit,
        inputs=[audio_input, session_state],
        outputs=outputs,
    )
    transcribe_btn.click(
        fn=on_audio_submit,
        inputs=[audio_input, session_state],
        outputs=outputs,
    )
    clear_audio_btn.click(
        fn=lambda: None,
        outputs=[audio_input],
    )

    reset_btn.click(fn=on_reset, inputs=[session_state], outputs=outputs)
    escalate_btn.click(
        fn=on_escalate,
        inputs=[session_state],
        outputs=outputs + [text_input],
    )

    # Preset buttons submit their own text
    # (phrase bound as a default arg to avoid the late-binding closure trap)
    for btn, phrase in zip(preset_btns, PRESETS["bank"]):
        btn.click(
            fn=lambda s, _phrase=phrase: on_text_submit(_phrase, s),
            inputs=[session_state],
            outputs=outputs + [text_input],
        )


if __name__ == "__main__":
    demo.launch()
dialogue.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PlotWeaver Voice Agent — Dialogue Manager
3
+ ==========================================
4
+ FSM for multi-turn Hausa conversations across 3 verticals.
5
+ State lives in Gradio session state (dict) — no Redis needed in the Space.
6
+ """
7
+ from __future__ import annotations
8
+ from dataclasses import dataclass, field, asdict
9
+ from enum import Enum
10
+ from typing import Optional
11
+
12
+
13
class Vertical(str, Enum):
    """Closed set of supported demo verticals.

    NOTE(review): not referenced in the visible code — SCENARIOS keys use the
    raw strings directly; kept as documentation of the valid values.
    """
    BANK = "bank"
    TELECOM = "telecom"
    ECOMMERCE = "ecommerce"
17
+
18
+
19
@dataclass
class DialogueState:
    """Per-session conversation state, serialized to/from a plain dict so it
    can live inside Gradio's ``gr.State``."""

    session_id: str                               # e.g. "sess_ab12cd34"
    vertical: str                                 # "bank" | "telecom" | "ecommerce"
    current_state: str = "greeting"               # current FSM node name
    slots: dict = field(default_factory=dict)     # entities collected so far
    turn_count: int = 0                           # turns processed this session
    escalate_to_human: bool = False               # handover flag set by the FSM
    history: list = field(default_factory=list)   # transcript message dicts

    def to_dict(self):
        """Serialize to a plain dict for storage in gr.State."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d):
        """Rebuild from a serialized dict; empty/None input yields None."""
        if not d:
            return None
        return cls(**d)
35
+
36
+
37
# FSM definition per vertical. Each state entry has:
#   "ha"/"en"     — bot prompt in Hausa / English gloss shown in the UI
#   "expects"     — slot type the NLU should extract next (None = terminal prompt)
#   "transitions" — intent name → next state name
#   "terminal"    — conversation ends at this state
#   "escalate"    — flags the session for human handover (see transition())
SCENARIOS = {
    "bank": {
        "name": "PlotWeaver Bank",
        "states": {
            "greeting": {
                "ha": "Sannu! Wannan shine mataimakin banki na PlotWeaver. Yaya zan taimake ka yau? Za ka iya ce 'duba ma'auni', 'toshe kati', ko 'canjin kuɗi'.",
                "en": "Hello! This is the PlotWeaver banking assistant. How can I help you today? You can say 'check balance', 'block card', or 'transfer money'.",
                "expects": "intent",
                "transitions": {"check_balance": "ask_account_number", "block_card": "confirm_block_card", "transfer_money": "ask_recipient"},
            },
            "ask_account_number": {
                "ha": "Don Allah ka faɗi lambobin ƙarshe huɗu na asusunka.",
                "en": "Please say the last four digits of your account number.",
                "expects": "digits",
                "transitions": {"provide_digits": "return_balance"},
            },
            "return_balance": {
                # Canned demo balance — no backend lookup happens.
                "ha": "Ma'aunin asusunka shine Naira dubu ɗari biyu da arba'in da biyar. Akwai wani abu?",
                "en": "Your account balance is two hundred forty-five thousand Naira. Anything else?",
                "expects": "yesno",
                "transitions": {"yes": "greeting", "no": "exit"},
            },
            "confirm_block_card": {
                "ha": "Don tabbatar, kana son toshe katinka? Ka ce 'i' ko 'a'a'.",
                "en": "To confirm, you want to block your card? Say 'yes' or 'no'.",
                "expects": "yesno",
                "transitions": {"yes": "card_blocked", "no": "greeting"},
            },
            "card_blocked": {
                "ha": "An toshe katinka. Sabon kati zai iso a cikin kwanaki uku zuwa biyar. Ana juya ka ga wakili don tabbatar.",
                "en": "Your card is blocked. A new card will arrive in 3-5 days. Transferring you to an agent for confirmation.",
                "expects": None, "terminal": True, "escalate": True,
            },
            "ask_recipient": {
                "ha": "Zuwa wa kake son turawa? Ka faɗi sunan mai karɓa.",
                "en": "Who do you want to transfer to? Say the recipient's name.",
                "expects": "name",
                "transitions": {"provide_name": "ask_amount"},
            },
            "ask_amount": {
                "ha": "Nawa kake son turawa, a Naira?",
                "en": "How much do you want to transfer, in Naira?",
                "expects": "amount",
                "transitions": {"provide_amount": "confirm_transfer"},
            },
            "confirm_transfer": {
                "ha": "Zan tura kuɗin yanzu. Ka ce 'i' don ci gaba.",
                "en": "I'll send the money now. Say 'yes' to continue.",
                "expects": "yesno",
                "transitions": {"yes": "transfer_done", "no": "greeting"},
            },
            "transfer_done": {
                "ha": "An tura kuɗin. Godiya da zabar PlotWeaver Bank.",
                "en": "Money sent. Thank you for choosing PlotWeaver Bank.",
                "expects": None, "terminal": True,
            },
        },
    },
    "telecom": {
        "name": "PlotWeaver Telecom",
        "states": {
            "greeting": {
                "ha": "Sannu! Wannan shine PlotWeaver Telecom. Kana son 'saya airtime', 'saya bundle', ko 'yin korafi'?",
                "en": "Hello! This is PlotWeaver Telecom. Would you like to 'buy airtime', 'buy bundle', or 'file a complaint'?",
                "expects": "intent",
                "transitions": {"buy_airtime": "ask_airtime_amount", "buy_bundle": "ask_bundle_type", "complaint": "ask_complaint"},
            },
            "ask_airtime_amount": {
                "ha": "Nawa na airtime kake son saya? Misali, Naira ɗari ko dubu.",
                "en": "How much airtime? For example 100 or 1000 Naira.",
                "expects": "amount",
                "transitions": {"provide_amount": "airtime_done"},
            },
            "airtime_done": {
                "ha": "An kara airtime. Ma'aunin ka sabo shine Naira dubu ɗaya da ɗari biyar.",
                "en": "Airtime loaded. Your new balance is 1500 Naira.",
                "expects": None, "terminal": True,
            },
            "ask_bundle_type": {
                "ha": "Wane irin bundle? Muna da 'rana', 'mako', ko 'wata'.",
                "en": "Which bundle type? 'day', 'week', or 'month'.",
                "expects": "bundle",
                "transitions": {"provide_bundle": "bundle_done"},
            },
            "bundle_done": {
                "ha": "An kunna bundle ɗinka. Za ka iya yin amfani da shi yanzu.",
                "en": "Your bundle is active. You can use it now.",
                "expects": None, "terminal": True,
            },
            "ask_complaint": {
                "ha": "Me ya faru? Ka bayyana matsalar da kake fuskanta.",
                "en": "What happened? Please describe the issue.",
                "expects": "text",
                "transitions": {"provide_text": "escalate"},
            },
            "escalate": {
                "ha": "Nagode. Zan juya ka ga wakili na mutum yanzu.",
                "en": "Thank you. I'll transfer you to a human agent now.",
                "expects": None, "terminal": True, "escalate": True,
            },
        },
    },
    "ecommerce": {
        "name": "PlotWeaver Delivery",
        "states": {
            "greeting": {
                "ha": "Sannu! Wannan shine PlotWeaver Delivery. Kana son 'bincika oda', 'sake tsara lokaci', ko 'mayar da kaya'?",
                "en": "Hello! This is PlotWeaver Delivery. Would you like to 'check order', 'reschedule', or 'return'?",
                "expects": "intent",
                "transitions": {"check_order": "ask_order_id", "reschedule": "ask_order_id_reschedule", "return_item": "ask_order_id_return"},
            },
            "ask_order_id": {
                "ha": "Ka faɗi lambar oda naka.",
                "en": "Say your order number.",
                "expects": "digits",
                "transitions": {"provide_digits": "order_status"},
            },
            "order_status": {
                "ha": "Oda ɗinka yana kan hanya. Za a isar gobe da yamma.",
                "en": "Your order is on the way. It will be delivered tomorrow evening.",
                "expects": None, "terminal": True,
            },
            "ask_order_id_reschedule": {
                "ha": "Ka faɗi lambar oda da kake son sake tsarawa.",
                "en": "Say the order number you want to reschedule.",
                "expects": "digits",
                "transitions": {"provide_digits": "ask_new_date"},
            },
            "ask_new_date": {
                "ha": "Wace rana kake so? Misali 'jumma'a' ko 'asabar'.",
                "en": "Which day? For example 'Friday' or 'Saturday'.",
                "expects": "date",
                "transitions": {"provide_date": "reschedule_done"},
            },
            "reschedule_done": {
                "ha": "An sake tsara isar. Za ka sami SMS na tabbatarwa.",
                "en": "Delivery rescheduled. You'll receive a confirmation SMS.",
                "expects": None, "terminal": True,
            },
            "ask_order_id_return": {
                "ha": "Ka faɗi lambar oda da kake son mayarwa.",
                "en": "Say the order number you want to return.",
                "expects": "digits",
                "transitions": {"provide_digits": "return_reason"},
            },
            "return_reason": {
                "ha": "Me ya sa kake son mayarwa?",
                "en": "Why do you want to return it?",
                "expects": "text",
                "transitions": {"provide_reason": "return_done"},
            },
            "return_done": {
                "ha": "An karɓi buƙatarka. Wakili zai tattara kaya a gobe.",
                "en": "Your request is received. An agent will collect the item tomorrow.",
                "expects": None, "terminal": True,
            },
        },
    },
}
196
+
197
+
198
def get_prompt(vertical: str, state_name: str) -> dict:
    """Look up the bot prompt for a state, as ``{"ha": …, "en": …}``.

    Two pseudo-states are handled outside the scenario tables:
    "escalate_virtual" (forced human handover) and "exit". Unknown state
    names fall back to a generic "didn't understand" prompt.
    """
    if state_name == "escalate_virtual":
        return {"ha": "Zan juya ka ga wakili na mutum yanzu. Ka jira ɗan lokaci.",
                "en": "I'll transfer you to a human agent now. Please hold."}
    if state_name == "exit":
        return {"ha": "Nagode. Sai watan.", "en": "Thank you. Goodbye."}
    entry = SCENARIOS[vertical]["states"].get(state_name)
    if entry is None:
        return {"ha": "Ban fahimci abin da ka ce ba.", "en": "I didn't understand."}
    return {"ha": entry["ha"], "en": entry["en"]}
208
+
209
+
210
def get_expected_slot(vertical: str, state_name: str) -> Optional[str]:
    """Return the slot type the given state is waiting for, or None when the
    state is unknown or has no "expects" entry (terminal states)."""
    entry = SCENARIOS[vertical]["states"].get(state_name)
    if entry is None:
        return None
    return entry.get("expects")
213
+
214
+
215
def transition(state: DialogueState, intent: str, entities: dict) -> DialogueState:
    """Advance the FSM one turn, mutating *state* in place and returning it.

    Slot values from *entities* are merged first. An explicit human-agent
    request — or a session running past 12 turns — short-circuits to
    "escalate_virtual". Intents with no matching transition leave the
    current state unchanged.
    """
    state.turn_count += 1
    state.slots.update(entities)

    # Hard escalation: explicit request, or the conversation is going nowhere.
    if intent == "human_agent" or state.turn_count > 12:
        state.current_state = "escalate_virtual"
        state.escalate_to_human = True
        return state

    here = SCENARIOS[state.vertical]["states"].get(state.current_state)
    if here is None:
        # Recover from a corrupt/unknown state by restarting the flow.
        state.current_state = "greeting"
        return state

    destination = here.get("transitions", {}).get(intent)
    if destination:
        state.current_state = destination
        if SCENARIOS[state.vertical]["states"].get(destination, {}).get("escalate"):
            state.escalate_to_human = True

    return state
nlu.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NLU — Hybrid Hausa intent + entity extraction.
3
+
4
+ Three-tier architecture:
5
+ 1. Rule-based keyword matcher (fast path, ~80% of demo utterances)
6
+ 2. Qwen2.5-1.5B-Instruct zero-shot JSON extractor (paraphrases, novel phrasings)
7
+ 3. Rule-based fallback (if LLM fails or returns unparseable output)
8
+
9
+ The LLM is lazy-loaded on first non-matched utterance so the Space boots fast.
10
+ In production this would be replaced with a fine-tuned classifier on
11
+ PlotWeaver's Hausa intent corpus.
12
+ """
13
+ from __future__ import annotations
14
+ import re
15
+ import json
16
+ import logging
17
+ from typing import Optional
18
+
19
+ logger = logging.getLogger("plotweaver.nlu")
20
+
21
# ---------------------------------------------------------------------------
# Layer 1: rule-based fast path (covers common demo phrases)
# ---------------------------------------------------------------------------
# Keyword → intent table. Matching is case-insensitive *substring* search
# over the space-padded utterance (see _norm / _match_intent_kw), and dict
# insertion order is the match priority — earlier intents win.
INTENT_KEYWORDS = {
    "check_balance": ["duba", "ma'auni", "balance", "kudi", "asusu"],
    "block_card": ["toshe", "kati", "block"],
    "transfer_money": ["tura", "canji", "canjin", "aika", "transfer"],
    "buy_airtime": ["airtime", "caji"],
    "buy_bundle": ["bundle", "data", "intanet"],
    "complaint": ["korafi", "matsala", "complain"],
    "check_order": ["bincika", "order", "oda"],
    "reschedule": ["sake tsara", "reschedule", "canja lokaci"],
    "return_item": ["mayar", "mayarwa", "return"],
    "human_agent": ["mutum", "wakili", "agent", "human"],
    # NOTE(review): "i " / " i" are bare substrings, not whole-word matches,
    # so they can also fire on longer words that start or end with "i";
    # this relies on the more specific intents above being tried first.
    # Confirm this ordering assumption is intended.
    "yes": ["i ", " i", "eh", "haka ne", "yes", "ok", "okay"],
    "no": ["a'a", "a'aa", "ba haka", " no", "no "],
}

# Hausa number words → single digit characters (both plain and hooked-ɗ
# spellings), used by _extract_digits when no literal digits are present.
WORD_DIGITS = {
    "sifili": "0", "daya": "1", "ɗaya": "1", "biyu": "2", "uku": "3",
    "hudu": "4", "huɗu": "4", "biyar": "5", "shida": "6", "bakwai": "7",
    "takwas": "8", "tara": "9",
}

# Spoken amount phrases → numeric value. _extract_amount scans these
# longest-phrase-first so e.g. "dubu biyar" (5000) wins over bare "dubu".
WORD_AMOUNTS = {
    "dubu goma": 10000, "dubu biyar": 5000, "dubu biyu": 2000,
    "dubu": 1000, "ɗari biyar": 500, "dari biyar": 500,
    "ɗari": 100, "dari": 100,
}
50
+
51
+
52
+ def _norm(t: str) -> str:
53
+ return " " + t.lower().strip() + " "
54
+
55
+
56
def _match_intent_kw(text: str) -> Optional[str]:
    """Return the first intent (in INTENT_KEYWORDS order) whose keyword
    occurs as a substring of the normalized text, or None if none match."""
    padded = _norm(text)
    return next(
        (
            intent
            for intent, keywords in INTENT_KEYWORDS.items()
            if any(kw in padded for kw in keywords)
        ),
        None,
    )
63
+
64
+
65
+ def _extract_digits(text: str) -> Optional[str]:
66
+ m = re.findall(r"\d+", text)
67
+ if m:
68
+ return "".join(m)
69
+ tokens = text.lower().split()
70
+ d = [WORD_DIGITS[tok] for tok in tokens if tok in WORD_DIGITS]
71
+ return "".join(d) if d else None
72
+
73
+
74
+ def _extract_amount(text: str) -> Optional[int]:
75
+ m = re.search(r"\d+", text)
76
+ if m:
77
+ return int(m.group())
78
+ t = text.lower()
79
+ for phrase in sorted(WORD_AMOUNTS.keys(), key=len, reverse=True):
80
+ if phrase in t:
81
+ return WORD_AMOUNTS[phrase]
82
+ return None
83
+
84
+
85
def _rule_based_parse(text: str, expected: Optional[str]) -> tuple[str, dict]:
    """Layers 1 + 3: deterministic keyword + slot matcher.

    The dialogue's expected slot type is honored first; if that slot
    cannot be filled, a generic intent keyword scan runs as a fallback.
    Returns (intent, entities); intent is "unknown" when nothing fits.
    """
    slots: dict = {}
    if not text or not text.strip():
        return "unknown", slots

    # "Take me to a human" always wins, regardless of the expected slot.
    if _match_intent_kw(text) == "human_agent":
        return "human_agent", slots

    if expected == "digits":
        digits = _extract_digits(text)
        if digits:
            slots["digits"] = digits
            return "provide_digits", slots
    elif expected == "amount":
        amount = _extract_amount(text)
        if amount is not None:
            slots["amount"] = amount
            return "provide_amount", slots
    elif expected == "name":
        words = text.strip().split()
        if words:
            # Heuristic: treat the last word of the utterance as the name.
            slots["name"] = words[-1]
            return "provide_name", slots
    elif expected == "date":
        slots["date"] = text.strip()
        return "provide_date", slots
    elif expected == "bundle":
        lowered = text.lower()
        for option in ("rana", "mako", "wata"):
            if option in lowered:
                slots["bundle"] = option
                return "provide_bundle", slots
    elif expected == "text":
        slots["text"] = text.strip()
        return "provide_text", slots
    elif expected == "yesno":
        answer = _match_intent_kw(text)
        if answer in ("yes", "no"):
            return answer, slots

    # Expected slot not filled (or no expectation) — generic intent scan.
    generic = _match_intent_kw(text)
    if generic:
        return generic, slots

    return "unknown", slots
138
+
139
+
140
# ---------------------------------------------------------------------------
# Layer 2: Qwen2.5-1.5B-Instruct zero-shot NLU
# ---------------------------------------------------------------------------
# Module-level singleton for the lazily-loaded LLM.
_llm_model = None
_llm_tokenizer = None
_llm_failed = False  # latched True after any load failure, to prevent retries


def _load_llm():
    """Lazy-load Qwen2.5-1.5B-Instruct. Called only when rule-based misses.

    Returns (model, tokenizer), or (None, None) when loading is impossible.
    A failed load is latched via _llm_failed so later calls fall straight
    through to the rule-based fallback instead of retrying.
    """
    global _llm_model, _llm_tokenizer, _llm_failed
    if _llm_failed:
        return None, None
    if _llm_model is None:
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer
            logger.info("Loading Qwen2.5-1.5B-Instruct for NLU…")
            model_id = "Qwen/Qwen2.5-1.5B-Instruct"
            _llm_tokenizer = AutoTokenizer.from_pretrained(model_id)
            _llm_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,  # CPU — bfloat16 not broadly supported
                low_cpu_mem_usage=True,
            )
            _llm_model.eval()
            logger.info("Qwen2.5-1.5B-Instruct ready.")
        except Exception as e:
            logger.warning(f"LLM load failed: {e}")
            _llm_failed = True
            return None, None
    return _llm_model, _llm_tokenizer
173
+
174
+
175
# Candidate intents per expected-slot context. Keeps the LLM prompt small
# and constrains output to valid options only.
# - Key None (and "intent") = open context: any top-level intent is valid.
# - Slot-typed keys restrict the model to the matching provide_* intent
#   plus the universal human_agent escape and unknown.
CANDIDATE_INTENTS = {
    None: ["check_balance", "block_card", "transfer_money",
           "buy_airtime", "buy_bundle", "complaint",
           "check_order", "reschedule", "return_item",
           "human_agent", "unknown"],
    "intent": ["check_balance", "block_card", "transfer_money",
               "buy_airtime", "buy_bundle", "complaint",
               "check_order", "reschedule", "return_item",
               "human_agent", "unknown"],
    "yesno": ["yes", "no", "human_agent", "unknown"],
    "digits": ["provide_digits", "human_agent", "unknown"],
    "amount": ["provide_amount", "human_agent", "unknown"],
    "name": ["provide_name", "human_agent", "unknown"],
    "date": ["provide_date", "human_agent", "unknown"],
    "bundle": ["provide_bundle", "human_agent", "unknown"],
    "text": ["provide_text", "human_agent", "unknown"],
}


# System prompt for the zero-shot NLU call. _llm_parse pairs this with a
# per-turn user message listing the utterance, the expected slot type, and
# the candidate intents; the model is asked to answer with bare JSON.
SYSTEM_PROMPT = """You are an intent classifier for a Hausa-language customer service voice agent.

Analyze the user's Hausa utterance and return a JSON object with:
- "intent": one of the candidate intents provided
- "entities": a dict of extracted values (may be empty)

Intent meanings:
- check_balance: user wants to check their account balance
- block_card: user wants to block or freeze their bank card
- transfer_money: user wants to transfer or send money
- buy_airtime: user wants to buy phone airtime
- buy_bundle: user wants to buy a data bundle
- complaint: user wants to file a complaint
- check_order: user wants to check an order status
- reschedule: user wants to reschedule a delivery
- return_item: user wants to return an item
- human_agent: user wants to speak to a human
- yes / no: affirmative or negative response
- provide_digits / provide_amount / provide_name / provide_date / provide_bundle / provide_text: user is providing specific information
- unknown: cannot determine the intent

Return ONLY a valid JSON object, no explanation. Example: {"intent": "check_balance", "entities": {}}"""
218
+
219
+
220
+ def _llm_parse(text: str, expected: Optional[str]) -> Optional[tuple[str, dict]]:
221
+ """Layer 2: zero-shot LLM classification. Returns None on any failure."""
222
+ model, tokenizer = _load_llm()
223
+ if model is None:
224
+ return None
225
+
226
+ candidates = CANDIDATE_INTENTS.get(expected, CANDIDATE_INTENTS[None])
227
+ user_prompt = (
228
+ f'Hausa utterance: "{text}"\n'
229
+ f'Expected slot type: {expected or "any"}\n'
230
+ f'Candidate intents: {", ".join(candidates)}\n\n'
231
+ 'Respond with JSON only.'
232
+ )
233
+ messages = [
234
+ {"role": "system", "content": SYSTEM_PROMPT},
235
+ {"role": "user", "content": user_prompt},
236
+ ]
237
+ try:
238
+ import torch
239
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
240
+ inputs = tokenizer(prompt, return_tensors="pt")
241
+ with torch.no_grad():
242
+ out = model.generate(
243
+ **inputs,
244
+ max_new_tokens=80,
245
+ do_sample=False,
246
+ pad_token_id=tokenizer.eos_token_id,
247
+ )
248
+ generated = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
249
+ logger.info(f"LLM raw output: {generated}")
250
+
251
+ # Extract JSON (model sometimes wraps it in markdown fences or prose)
252
+ m = re.search(r"\{.*?\}", generated, re.DOTALL)
253
+ if not m:
254
+ return None
255
+ parsed = json.loads(m.group())
256
+ intent = parsed.get("intent", "unknown")
257
+ entities = parsed.get("entities", {}) or {}
258
+ if not isinstance(entities, dict):
259
+ entities = {}
260
+ # Validate intent is in candidate list
261
+ if intent not in candidates:
262
+ logger.info(f"LLM returned out-of-candidate intent: {intent}")
263
+ return None
264
+ return intent, entities
265
+ except Exception as e:
266
+ logger.warning(f"LLM inference failed: {e}")
267
+ return None
268
+
269
+
270
+ # ---------------------------------------------------------------------------
271
+ # Public API
272
+ # ---------------------------------------------------------------------------
273
def parse(text: str, expected: Optional[str] = None,
          use_llm: bool = True) -> tuple[str, dict, str]:
    """
    Hybrid NLU. Returns (intent, entities, source) where source is one of
    'rule', 'llm', or 'rule_fallback'.

    Flow:
      1. Rule-based keyword/slot matcher (fast, deterministic)
      2. If that yields 'unknown' and use_llm is True: Qwen2.5 zero-shot
      3. If the LLM fails or answers out-of-band: keep the rule result
    """
    rule_intent, rule_entities = _rule_based_parse(text, expected)

    # Fast path: the deterministic layer understood the utterance, or the
    # caller disabled the LLM entirely.
    if rule_intent != "unknown" or not use_llm:
        return rule_intent, rule_entities, "rule"

    llm_result = _llm_parse(text, expected)
    if llm_result is None:
        return rule_intent, rule_entities, "rule_fallback"

    llm_intent, llm_entities = llm_result

    # For strict-format slots, trust our deterministic extractors over the
    # LLM's entities (it might hallucinate digits or amounts).
    if expected == "digits":
        digits = _extract_digits(text)
        if digits:
            llm_entities["digits"] = digits
    elif expected == "amount":
        amount = _extract_amount(text)
        if amount is not None:
            llm_entities["amount"] = amount

    return llm_intent, llm_entities, "llm"
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ transformers==4.48.0
2
+ torch==2.5.1
3
+ accelerate==1.2.1
4
+ numpy==2.1.3
5
+ scipy==1.15.0
6
+ sentencepiece==0.2.0
7
+ audioop-lts==0.2.2