Reethika30 committed on
Commit
2e92cbb
·
verified ·
1 Parent(s): ec44144

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +206 -0
app.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio Web Demo for Seq2Seq Document Generation
3
+ ================================================
4
+ Interactive UI: paste a long source document, get a generated summary
5
+ with greedy or beam search decoding, plus an attention heatmap.
6
+
7
+ Run locally: python app.py
8
+ Deploy: HuggingFace Spaces (Gradio SDK)
9
+ """
10
+
11
+ import os
12
+ import sys
13
+ import io
14
+ import base64
15
+ import torch
16
+ import gradio as gr
17
+ import matplotlib
18
+ matplotlib.use("Agg")
19
+ import matplotlib.pyplot as plt
20
+ import numpy as np
21
+
22
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
23
+
24
+ from model import build_model
25
+ from inference import generate_document, greedy_decode
26
+ from preprocessing import Vocabulary, tokenize
27
+
28
+ # ----------------------------------------------------------------------
29
+ # Load model once at startup
30
+ # ----------------------------------------------------------------------
# Absolute directory of this file; all data/model paths are resolved against it
# so the app works regardless of the current working directory.
BASE = os.path.dirname(os.path.abspath(__file__))
DEVICE = torch.device("cpu")

print("Loading vocabularies...")
# NOTE(review): the ".pkl" extension suggests Vocabulary.load unpickles these
# files -- only load vocabularies that ship with the app; verify in preprocessing.
SRC_VOCAB = Vocabulary.load(os.path.join(BASE, "data", "src_vocab.pkl"))
TGT_VOCAB = Vocabulary.load(os.path.join(BASE, "data", "tgt_vocab.pkl"))

print("Loading model...")
# The hyperparameters below have to describe the same architecture the
# checkpoint was saved from: load_state_dict() is strict by default and fails
# on missing/unexpected keys or mismatched shapes.
MODEL = build_model(
    src_vocab_size=len(SRC_VOCAB),
    tgt_vocab_size=len(TGT_VOCAB),
    embed_dim=256,
    hidden_dim=256,
    attention_dim=128,
    n_layers=2,
    dropout=0.3,
    pad_idx=Vocabulary.PAD_IDX,
    sos_idx=Vocabulary.SOS_IDX,
    device=DEVICE,
)

# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. That is acceptable only because the
# checkpoint is bundled with the app; never point this at an untrusted file.
ckpt = torch.load(
    os.path.join(BASE, "models", "best_model.pt"),
    map_location=DEVICE,
    weights_only=False,
)
MODEL.load_state_dict(ckpt["model_state_dict"])
MODEL.eval()  # inference mode: disables dropout

# "val_loss" is read with .get(), so it may be absent; formatting None with
# ":.4f" would raise TypeError and abort startup, hence the explicit fallback.
_val_loss = ckpt.get("val_loss")
_val_loss_text = "n/a" if _val_loss is None else f"{_val_loss:.4f}"
print(f"Loaded checkpoint from epoch {ckpt.get('epoch')} "
      f"(val_loss: {_val_loss_text})")
53
+
54
+
55
+ # ----------------------------------------------------------------------
56
+ # Helpers
57
+ # ----------------------------------------------------------------------
def attention_heatmap(source_text, output_tokens, attn_matrix):
    """Render attention weights as a matplotlib figure.

    Args:
        source_text: Raw source document; it is re-tokenized here with
            ``tokenize`` to produce the x-axis labels.
        output_tokens: Generated tokens, used as y-axis labels.
        attn_matrix: 2-D array indexed ``[output step, source position]``.
            The source axis is assumed to carry one leading ``<SOS>`` and one
            trailing ``<EOS>`` column -- TODO confirm against the encoder input.

    Returns:
        A ``matplotlib.figure.Figure`` holding the heatmap. The figure is not
        registered with ``pyplot``, so it is garbage-collected once the caller
        drops it.
    """
    # Local import keeps this block self-contained. A standalone Figure is used
    # instead of plt.subplots(): pyplot keeps a reference to every figure it
    # creates until plt.close() is called, which leaks memory in a server that
    # renders one heatmap per request.
    from matplotlib.figure import Figure

    # Clamp at zero: with fewer than two source columns a negative slice bound
    # would silently drop tokens from the END of the list instead of limiting it.
    n_src = max(attn_matrix.shape[1] - 2, 0)
    src_tokens = tokenize(source_text)[:n_src]
    out_tokens = output_tokens[:attn_matrix.shape[0]]
    x_labels = ["<SOS>"] + src_tokens + ["<EOS>"]

    # Trim attention matrix to match displayed tokens
    attn = attn_matrix[:len(out_tokens), :len(x_labels)]
    # Use the real column count so ticks and labels never outnumber the columns.
    n_cols = attn.shape[1]

    fig = Figure(figsize=(max(8, len(src_tokens) * 0.25),
                          max(4, len(out_tokens) * 0.35)))
    ax = fig.subplots()
    im = ax.imshow(attn, cmap="viridis", aspect="auto")
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(x_labels[:n_cols], rotation=75, fontsize=8)
    ax.set_yticks(range(len(out_tokens)))
    ax.set_yticklabels(out_tokens, fontsize=9)
    ax.set_xlabel("Source Tokens")
    ax.set_ylabel("Generated Tokens")
    ax.set_title("Bahdanau Attention Weights")
    fig.colorbar(im, ax=ax, fraction=0.025)
    fig.tight_layout()
    return fig
80
+
81
+
def generate(source_text, method, beam_width, max_len):
    """Main inference function called by Gradio.

    Args:
        source_text: Source document typed or pasted by the user.
        method: ``"Greedy"`` selects greedy decoding; any other value is
            treated as beam search.
        beam_width: Beam width (only used for beam search); cast to ``int``.
        max_len: Maximum number of output tokens; cast to ``int``.

    Returns:
        A 3-tuple ``(summary_text, figure_or_None, info_text)``. On failure the
        first element carries an ``"Error: ..."`` message instead of raising,
        so the UI always receives displayable values.
    """
    import traceback  # local import: only needed on the error path

    # Treat a missing value like blank input instead of crashing on
    # None.strip() (e.g. when the function is invoked without a text value).
    if source_text is None or not source_text.strip():
        return "Please enter source text.", None, ""

    try:
        if method == "Greedy":
            text, meta = generate_document(
                MODEL, source_text, SRC_VOCAB, TGT_VOCAB,
                method="greedy", max_len=int(max_len), device=DEVICE,
            )
            out_tokens = text.split()

            # Build attention heatmap. This is best-effort: a plotting failure
            # must not discard the summary that was already generated.
            fig = None
            attn = meta.get("attention")
            if attn is not None and hasattr(attn, "shape"):
                try:
                    if torch.is_tensor(attn):
                        # detach() first: .numpy() refuses tensors that still
                        # require grad.
                        attn_array = attn.detach().cpu().numpy()
                    else:
                        attn_array = np.asarray(attn)
                    fig = attention_heatmap(source_text, out_tokens, attn_array)
                except Exception as e:
                    print(f"Heatmap error: {e}")

            info = f"Method: Greedy decode | Output length: {len(out_tokens)} tokens"
            return text, fig, info

        # Beam
        text, meta = generate_document(
            MODEL, source_text, SRC_VOCAB, TGT_VOCAB,
            method="beam", beam_width=int(beam_width),
            max_len=int(max_len), device=DEVICE,
        )
        # A score that is present but None would break the ":.4f" format and
        # throw away a finished summary; fall back to the same default as a
        # missing key.
        score = meta.get("score", 0)
        if score is None:
            score = 0
        info = (f"Method: Beam Search (width={int(beam_width)}) | "
                f"Score: {score:.4f} | "
                f"Output length: {len(text.split())} tokens")
        return text, None, info

    except Exception as e:
        # Top-level UI boundary: report the problem to the user, but keep the
        # traceback in the server log so failures stay diagnosable.
        traceback.print_exc()
        return f"Error: {e}", None, ""
119
+
120
+
121
+ # ----------------------------------------------------------------------
122
+ # Gradio UI
123
+ # ----------------------------------------------------------------------
# Sample documents offered in the UI, one per supported document genre.
_FINANCE_EXAMPLE = (
    "The quarterly financial report for TechNova indicates revenue of $2500M, "
    "representing a 15% increase year over year. Operating expenses increased "
    "to $1200M. Net income was $450M. The board approved a dividend of $2.50 "
    "per share. Management projects continued growth in the coming quarters "
    "driven by AI integration."
)

_PRODUCT_EXAMPLE = (
    "The ProMax X1 by CloudPeak features a 8-core processor, 6000mAh battery, "
    "and AI-powered assistant. It is designed for professionals who need high "
    "performance computing. Available in Black, Silver, and Blue, the device "
    "weighs 195g and includes fast charging and biometric auth. Pricing starts "
    "at $999."
)

_RESEARCH_EXAMPLE = (
    "This study examines the relationship between remote work frequency and "
    "productivity using a dataset of 5000 observations from Fortune 500 "
    "companies. We employ regression analysis to analyze temporal patterns. "
    "Results indicate a strong positive correlation (p < 0.001). The findings "
    "suggest that targeted interventions improve outcomes."
)

# Each row is [source text, decoding method, beam width, max output tokens].
EXAMPLES = [
    [_FINANCE_EXAMPLE, "Greedy", 5, 60],
    [_PRODUCT_EXAMPLE, "Beam", 5, 60],
    [_RESEARCH_EXAMPLE, "Greedy", 5, 60],
]
143
+
# Markdown shown at the top of the demo page. Built from a list of lines so
# that each line of the rendered text is easy to edit; the leading and trailing
# empty entries reproduce the newlines right after/before the original
# triple quotes.
# NOTE(review): assumes the markdown lines carry no leading indentation.
DESCRIPTION = "\n".join([
    "",
    "# Seq2Seq Document Generation with Bahdanau Attention",
    "",
    "Encoder-decoder model that compresses long-form documents (financial reports,",
    "product specs, research abstracts) into concise summaries.",
    "",
    "- **Architecture:** Bidirectional GRU Encoder + Bahdanau (Additive) Attention + GRU Decoder",
    "- **Parameters:** 3.9M | **Framework:** PyTorch",
    "- **Training:** 15 epochs on 5,000 synthetic pairs | **Best Val PPL:** 9.08",
    "- **Decoding:** Greedy + Beam Search (width 5)",
    "",
    "Paste a document below, pick a decoding strategy, and see the model",
    "generate a summary. Greedy mode also renders the **attention heatmap**",
    "showing which source tokens the decoder focused on at each step.",
    "",
])
159
+
# Page layout: description on top, then a two-column row (inputs on the left,
# results on the right), followed by the clickable examples and a repo link.
# NOTE(review): the nesting below was reconstructed from a diff whose
# indentation was stripped -- confirm that the two sliders sit in the same
# inner row as the decoding-method radio and that the button sits below it.
with gr.Blocks(title="Seq2Seq Doc Generation") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: source document and decoding controls.
        with gr.Column(scale=3):
            src = gr.Textbox(
                label="Source Document",
                placeholder="Paste a long document here...",
                lines=8,
            )
            with gr.Row():
                method = gr.Radio(
                    choices=["Greedy", "Beam"],
                    value="Greedy",
                    label="Decoding Method",
                )
                beam_width = gr.Slider(
                    minimum=2, maximum=10, value=5, step=1,
                    label="Beam Width",
                )
                max_len = gr.Slider(
                    minimum=20, maximum=120, value=60, step=5,
                    label="Max Output Tokens",
                )
            btn = gr.Button("Generate", variant="primary")

        # Right column: generated text, decoding details and the heatmap.
        with gr.Column(scale=2):
            output = gr.Textbox(label="Generated Summary", lines=4)
            info = gr.Textbox(label="Decoding Info", lines=1)
            heatmap = gr.Plot(label="Attention Heatmap (Greedy only)")

    # The example picker and the button feed the same components, in the order
    # generate() takes its arguments and returns its results.
    demo_inputs = [src, method, beam_width, max_len]
    demo_outputs = [output, heatmap, info]

    gr.Examples(
        examples=EXAMPLES,
        inputs=demo_inputs,
        outputs=demo_outputs,
        fn=generate,
        cache_examples=False,
    )

    btn.click(fn=generate, inputs=demo_inputs, outputs=demo_outputs)

    gr.Markdown(
        "---\n"
        "**Repo:** [github.com/Reethika30/nlp-seq2seq-docgen]"
        "(https://github.com/Reethika30/nlp-seq2seq-docgen)"
    )
203
+
204
+
if __name__ == "__main__":
    # Listen on all network interfaces (not just localhost) and do not request
    # a public Gradio share link.
    launch_options = {"server_name": "0.0.0.0", "share": False}
    demo.launch(**launch_options)