sayanAIAI committed
Commit d005cea · verified · 1 parent: a323f1e

Update main.py

Files changed (1): main.py (+123 −13)
main.py CHANGED
@@ -3,16 +3,28 @@ os.environ['HF_HOME'] = '/tmp'
 
 from flask import Flask, request, jsonify, render_template
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
-import math, textwrap
+import json, re, time
 
 app = Flask(__name__)
 
-MODEL_NAME = "sshleifer/distilbart-cnn-12-6"
+# -------------------------
+# Models (CPU as requested)
+# -------------------------
+# Primary summarizer: higher-quality model
+MODEL_NAME = "facebook/bart-large-cnn"
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
-summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=-1)  # set device appropriately
+summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=-1)  # CPU
 
-# Simple mapping of presets to generation lengths
+# Small instruction model to choose length/tone when "auto" is requested
+PARAM_MODEL_NAME = "google/flan-t5-small"
+param_tokenizer = AutoTokenizer.from_pretrained(PARAM_MODEL_NAME)
+param_model = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL_NAME)
+param_generator = pipeline("text2text-generation", model=param_model, tokenizer=param_tokenizer, device=-1)  # CPU
+
+# -------------------------
+# Presets & helpers
+# -------------------------
 LENGTH_PRESETS = {
     "short": {"min_length": 20, "max_length": 60},
     "medium": {"min_length": 60, "max_length": 130},
@@ -47,31 +59,127 @@ def apply_tone_instruction(text, tone):
         instr = "Summarize:"
     return f"{instr}\n\n{text}"
 
-# NEW: Route to show summarizer.html (fixes 404)
+# small regex int extractor
+def _first_int_from_text(s, fallback=None):
+    m = re.search(r"\d{1,5}", s)
+    return int(m.group()) if m else fallback
+
+def generate_summarization_config(text):
+    """
+    Uses the small instruction model to recommend:
+      - length: short|medium|long
+      - min_words, max_words (integers)
+      - tone: neutral|formal|casual|bullet
+    Returns a normalized dict with keys: length, min_length, max_length, tone
+    Falls back to heuristics on failure.
+    """
+    prompt = (
+        "You are a helpful assistant that recommends summarization settings.\n"
+        "Given the following source text, pick a summary LENGTH category (short/medium/long), "
+        "an estimated MIN and MAX length in words for the summary, and a TONE (neutral/formal/casual/bullet).\n"
+        "Respond in a single line in this exact JSON format (no extra commentary):\n"
+        '{"length":"SHORT_OR_MEDIUM_OR_LONG","min_words":MIN,"max_words":MAX,"tone":"NEUTRAL|FORMAL|CASUAL|BULLET"}\n\n'
+        "Text:\n'''"
+        + (text[:6000])
+        + "'''"
+    )
+    try:
+        out = param_generator(prompt, max_length=200, do_sample=False)[0]["generated_text"].strip()
+        # attempt parse
+        try:
+            cfg = json.loads(out)
+        except Exception:
+            jmatch = re.search(r"\{.*\}", out, re.DOTALL)
+            if jmatch:
+                raw = jmatch.group()
+                raw = raw.replace("'", '"')
+                cfg = json.loads(raw)
+            else:
+                raise
+
+        length = cfg.get("length", "").lower()
+        tone = cfg.get("tone", "").lower()
+        min_w = cfg.get("min_words")
+        max_w = cfg.get("max_words")
+
+        # sensible defaults if parse odd
+        if length not in ("short", "medium", "long"):
+            words = len(text.split())
+            length = "short" if words < 150 else ("medium" if words < 800 else "long")
+        if tone not in ("neutral", "formal", "casual", "bullet"):
+            tone = "neutral"
+
+        # fallback numeric extraction
+        if not isinstance(min_w, int):
+            min_w = _first_int_from_text(out, fallback=None)
+        if not isinstance(max_w, int):
+            max_w = _first_int_from_text(out[::-1], fallback=None)
+
+        defaults = {"short": (15, 50), "medium": (50, 130), "long": (130, 300)}
+        dmin, dmax = defaults.get(length, (50, 130))
+        min_len = int(min_w) if isinstance(min_w, int) else dmin
+        max_len = int(max_w) if isinstance(max_w, int) else dmax
+
+        # clamp to sane bounds
+        min_len = max(5, min(min_len, 2000))
+        max_len = max(min_len + 5, min(max_len, 4000))
+
+        return {"length": length, "min_length": min_len, "max_length": max_len, "tone": tone}
+    except Exception:
+        # fallback heuristic
+        words = len(text.split())
+        length = "short" if words < 150 else ("medium" if words < 800 else "long")
+        d = {"short": (15, 50), "medium": (50, 130), "long": (130, 300)}
+        mn, mx = d[length]
+        return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}
+
+# -------------------------
+# Routes
+# -------------------------
 @app.route("/")
 def home():
+    # expects templates/index.html to exist (your frontend)
     return render_template("index.html")
 
 @app.route("/summarize", methods=["POST"])
 def summarize_route():
+    start_time = time.time()
     data = request.get_json(force=True)
-    text = data.get("text", "")[:20000]
-    length = data.get("length", "medium")
-    tone = data.get("tone", "neutral")
+    text = data.get("text", "")[:20000]  # cap input
+    requested_length = (data.get("length") or "medium").lower()
+    requested_tone = (data.get("tone") or "neutral").lower()
 
     if not text or len(text.split()) < 5:
         return jsonify({"error": "Input too short."}), 400
 
-    preset = LENGTH_PRESETS.get(length, LENGTH_PRESETS["medium"])
+    # If user asks AI to choose settings
+    if requested_length in ("auto", "ai") or requested_tone in ("auto", "ai"):
+        cfg = generate_summarization_config(text)
+        length = cfg.get("length", "medium")
+        tone = cfg.get("tone", "neutral")
+        preset_min = cfg.get("min_length")
+        preset_max = cfg.get("max_length")
+        preset = LENGTH_PRESETS.get(length, LENGTH_PRESETS["medium"])
+    else:
+        length = requested_length if requested_length in LENGTH_PRESETS else "medium"
+        tone = requested_tone if requested_tone in ("neutral", "formal", "casual", "bullet") else "neutral"
+        preset = LENGTH_PRESETS.get(length, LENGTH_PRESETS["medium"])
+        preset_min = preset["min_length"]
+        preset_max = preset["max_length"]
+
+    # chunk input for long texts
     chunks = chunk_text_by_chars(text, max_chars=1500, overlap=200)
     summaries = []
 
     for chunk in chunks:
         prompted = apply_tone_instruction(chunk, tone)
+        min_l = int(preset_min) if preset_min is not None else preset["min_length"]
+        max_l = int(preset_max) if preset_max is not None else preset["max_length"]
+
         out = summarizer(
             prompted,
-            min_length=preset["min_length"],
-            max_length=preset["max_length"],
+            min_length=min_l,
+            max_length=max_l,
             truncation=True
         )[0]["summary_text"]
         summaries.append(out.strip())
@@ -92,7 +200,9 @@ def summarize_route():
         lines = [l.strip() for s in final.splitlines() for l in s.split(". ") if l.strip()]
         final = "\n".join(f"- {l.rstrip('.')}" for l in lines[:20])
 
-    return jsonify({"summary": final})
+    elapsed = time.time() - start_time
+    return jsonify({"summary": final, "meta": {"length_choice": length, "tone": tone, "time_seconds": round(elapsed, 2)}})
 
 if __name__ == "__main__":
-    app.run(debug=True, port=7860)
+    # keep debug off in production; using CPU as requested
+    app.run(host="0.0.0.0", port=7860, debug=True)
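
For reference, a minimal client sketch of how the updated /summarize route could be exercised once this revision is running. It is not part of the commit: the URL and port follow app.run(host="0.0.0.0", port=7860) above, the requests package and the sample payload text are assumptions for illustration, and "auto" triggers the flan-t5-small parameter picker added in this diff.

# Hypothetical client for the /summarize route added above (not part of the commit).
# Assumes the Flask app is running locally on port 7860 and `requests` is installed.
import requests

payload = {
    "text": (
        "Transformers are a family of neural network architectures built around "
        "self-attention. Originally introduced for machine translation, they are "
        "now widely used for summarization, question answering, and other tasks."
    ),
    "length": "auto",   # "short" / "medium" / "long", or "auto"/"ai" to let the model decide
    "tone": "neutral",  # "neutral" / "formal" / "casual" / "bullet", or "auto"/"ai"
}

resp = requests.post("http://localhost:7860/summarize", json=payload, timeout=300)
resp.raise_for_status()
body = resp.json()

print(body["summary"])
print(body["meta"])  # {"length_choice": ..., "tone": ..., "time_seconds": ...}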