Spaces:
Runtime error
Runtime error
Update main.py
Browse files
main.py
CHANGED
|
@@ -3,16 +3,28 @@ os.environ['HF_HOME'] = '/tmp'
|
|
| 3 |
|
| 4 |
from flask import Flask, request, jsonify, render_template
|
| 5 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
| 6 |
-
import
|
| 7 |
|
| 8 |
app = Flask(__name__)
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 12 |
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
| 13 |
-
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=-1) #
|
| 14 |
|
| 15 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
LENGTH_PRESETS = {
|
| 17 |
"short": {"min_length": 20, "max_length": 60},
|
| 18 |
"medium": {"min_length": 60, "max_length": 130},
|
|
@@ -47,31 +59,127 @@ def apply_tone_instruction(text, tone):
|
|
| 47 |
instr = "Summarize:"
|
| 48 |
return f"{instr}\n\n{text}"
|
| 49 |
|
| 50 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
@app.route("/")
|
| 52 |
def home():
|
|
|
|
| 53 |
return render_template("index.html")
|
| 54 |
|
| 55 |
@app.route("/summarize", methods=["POST"])
|
| 56 |
def summarize_route():
|
|
|
|
| 57 |
data = request.get_json(force=True)
|
| 58 |
-
text = data.get("text", "")[:20000]
|
| 59 |
-
|
| 60 |
-
|
| 61 |
|
| 62 |
if not text or len(text.split()) < 5:
|
| 63 |
return jsonify({"error": "Input too short."}), 400
|
| 64 |
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
chunks = chunk_text_by_chars(text, max_chars=1500, overlap=200)
|
| 67 |
summaries = []
|
| 68 |
|
| 69 |
for chunk in chunks:
|
| 70 |
prompted = apply_tone_instruction(chunk, tone)
|
|
|
|
|
|
|
|
|
|
| 71 |
out = summarizer(
|
| 72 |
prompted,
|
| 73 |
-
min_length=
|
| 74 |
-
max_length=
|
| 75 |
truncation=True
|
| 76 |
)[0]["summary_text"]
|
| 77 |
summaries.append(out.strip())
|
|
@@ -92,7 +200,9 @@ def summarize_route():
|
|
| 92 |
lines = [l.strip() for s in final.splitlines() for l in s.split(". ") if l.strip()]
|
| 93 |
final = "\n".join(f"- {l.rstrip('.')}" for l in lines[:20])
|
| 94 |
|
| 95 |
-
|
|
|
|
| 96 |
|
| 97 |
if __name__ == "__main__":
|
| 98 |
-
|
|
|
|
|
|
| 3 |
|
| 4 |
from flask import Flask, request, jsonify, render_template
|
| 5 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
| 6 |
+
import json, re, time
|
| 7 |
|
| 8 |
app = Flask(__name__)
|
| 9 |
|
| 10 |
+
# -------------------------
# Models (CPU as requested)
# -------------------------
# Primary summarizer: higher-quality seq2seq model, loaded once at import time.
MODEL_NAME = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# device=-1 forces CPU inference for the summarization pipeline.
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=-1)  # CPU

# Small instruction model used only to recommend length/tone settings when
# the client requests "auto"; kept separate from the summarizer above.
PARAM_MODEL_NAME = "google/flan-t5-small"
param_tokenizer = AutoTokenizer.from_pretrained(PARAM_MODEL_NAME)
param_model = AutoModelForSeq2SeqLM.from_pretrained(PARAM_MODEL_NAME)
param_generator = pipeline("text2text-generation", model=param_model, tokenizer=param_tokenizer, device=-1)  # CPU
|
| 24 |
+
|
| 25 |
+
# -------------------------
|
| 26 |
+
# Presets & helpers
|
| 27 |
+
# -------------------------
|
| 28 |
LENGTH_PRESETS = {
|
| 29 |
"short": {"min_length": 20, "max_length": 60},
|
| 30 |
"medium": {"min_length": 60, "max_length": 130},
|
|
|
|
| 59 |
instr = "Summarize:"
|
| 60 |
return f"{instr}\n\n{text}"
|
| 61 |
|
| 62 |
+
# small regex int extractor
|
| 63 |
+
def _first_int_from_text(s, fallback=None):
|
| 64 |
+
m = re.search(r"\d{1,5}", s)
|
| 65 |
+
return int(m.group()) if m else fallback
|
| 66 |
+
|
| 67 |
+
def generate_summarization_config(text):
    """
    Ask the small instruction model to recommend summarization settings.

    The model is prompted to emit one JSON line with:
      - length: short|medium|long
      - min_words, max_words (integers)
      - tone: neutral|formal|casual|bullet

    Returns a normalized dict with keys: length, min_length, max_length, tone.
    Falls back to a word-count heuristic on any failure (bad JSON, model error).
    """
    prompt = (
        "You are a helpful assistant that recommends summarization settings.\n"
        "Given the following source text, pick a summary LENGTH category (short/medium/long), "
        "an estimated MIN and MAX length in words for the summary, and a TONE (neutral/formal/casual/bullet).\n"
        "Respond in a single line in this exact JSON format (no extra commentary):\n"
        '{"length":"SHORT_OR_MEDIUM_OR_LONG","min_words":MIN,"max_words":MAX,"tone":"NEUTRAL|FORMAL|CASUAL|BULLET"}\n\n'
        "Text:\n'''"
        + (text[:6000])  # cap prompt size for the small model
        + "'''"
    )
    try:
        out = param_generator(prompt, max_length=200, do_sample=False)[0]["generated_text"].strip()
        # Attempt strict JSON parse first, then salvage the first {...} span.
        try:
            cfg = json.loads(out)
        except Exception:
            jmatch = re.search(r"\{.*\}", out, re.DOTALL)
            if jmatch:
                # Models often emit single-quoted pseudo-JSON; normalize quotes.
                cfg = json.loads(jmatch.group().replace("'", '"'))
            else:
                raise

        length = cfg.get("length", "").lower()
        tone = cfg.get("tone", "").lower()
        min_w = cfg.get("min_words")
        max_w = cfg.get("max_words")

        # Sensible defaults if the parsed categories are unusable.
        if length not in ("short", "medium", "long"):
            words = len(text.split())
            length = "short" if words < 150 else ("medium" if words < 800 else "long")
        if tone not in ("neutral", "formal", "casual", "bullet"):
            tone = "neutral"

        # Fallback numeric extraction when min/max are missing from the JSON.
        # BUG FIX: the original searched the *reversed* output (out[::-1]) for
        # max_words, which returns the digits of the last number reversed
        # (e.g. "130" -> 31). Collect all integers in order instead and use
        # the first as min and the last as max.
        ints = [int(m) for m in re.findall(r"\d{1,5}", out)]
        if not isinstance(min_w, int):
            min_w = ints[0] if ints else None
        if not isinstance(max_w, int):
            max_w = ints[-1] if ints else None

        defaults = {"short": (15, 50), "medium": (50, 130), "long": (130, 300)}
        dmin, dmax = defaults.get(length, (50, 130))
        min_len = min_w if isinstance(min_w, int) else dmin
        max_len = max_w if isinstance(max_w, int) else dmax

        # Clamp to sane bounds and guarantee max_length > min_length.
        min_len = max(5, min(min_len, 2000))
        max_len = max(min_len + 5, min(max_len, 4000))

        return {"length": length, "min_length": min_len, "max_length": max_len, "tone": tone}
    except Exception:
        # Fallback heuristic: choose a preset purely from input word count.
        words = len(text.split())
        length = "short" if words < 150 else ("medium" if words < 800 else "long")
        mn, mx = {"short": (15, 50), "medium": (50, 130), "long": (130, 300)}[length]
        return {"length": length, "min_length": mn, "max_length": mx, "tone": "neutral"}
|
| 135 |
+
|
| 136 |
+
# -------------------------
|
| 137 |
+
# Routes
|
| 138 |
+
# -------------------------
|
| 139 |
@app.route("/")
def home():
    # Serve the single-page frontend; expects templates/index.html to exist
    # alongside this module (standard Flask template directory layout).
    return render_template("index.html")
|
| 143 |
|
| 144 |
@app.route("/summarize", methods=["POST"])
|
| 145 |
def summarize_route():
|
| 146 |
+
start_time = time.time()
|
| 147 |
data = request.get_json(force=True)
|
| 148 |
+
text = data.get("text", "")[:20000] # cap input
|
| 149 |
+
requested_length = (data.get("length") or "medium").lower()
|
| 150 |
+
requested_tone = (data.get("tone") or "neutral").lower()
|
| 151 |
|
| 152 |
if not text or len(text.split()) < 5:
|
| 153 |
return jsonify({"error": "Input too short."}), 400
|
| 154 |
|
| 155 |
+
# If user asks AI to choose settings
|
| 156 |
+
if requested_length in ("auto", "ai") or requested_tone in ("auto", "ai"):
|
| 157 |
+
cfg = generate_summarization_config(text)
|
| 158 |
+
length = cfg.get("length", "medium")
|
| 159 |
+
tone = cfg.get("tone", "neutral")
|
| 160 |
+
preset_min = cfg.get("min_length")
|
| 161 |
+
preset_max = cfg.get("max_length")
|
| 162 |
+
preset = LENGTH_PRESETS.get(length, LENGTH_PRESETS["medium"])
|
| 163 |
+
else:
|
| 164 |
+
length = requested_length if requested_length in LENGTH_PRESETS else "medium"
|
| 165 |
+
tone = requested_tone if requested_tone in ("neutral", "formal", "casual", "bullet") else "neutral"
|
| 166 |
+
preset = LENGTH_PRESETS.get(length, LENGTH_PRESETS["medium"])
|
| 167 |
+
preset_min = preset["min_length"]
|
| 168 |
+
preset_max = preset["max_length"]
|
| 169 |
+
|
| 170 |
+
# chunk input for long texts
|
| 171 |
chunks = chunk_text_by_chars(text, max_chars=1500, overlap=200)
|
| 172 |
summaries = []
|
| 173 |
|
| 174 |
for chunk in chunks:
|
| 175 |
prompted = apply_tone_instruction(chunk, tone)
|
| 176 |
+
min_l = int(preset_min) if preset_min is not None else preset["min_length"]
|
| 177 |
+
max_l = int(preset_max) if preset_max is not None else preset["max_length"]
|
| 178 |
+
|
| 179 |
out = summarizer(
|
| 180 |
prompted,
|
| 181 |
+
min_length=min_l,
|
| 182 |
+
max_length=max_l,
|
| 183 |
truncation=True
|
| 184 |
)[0]["summary_text"]
|
| 185 |
summaries.append(out.strip())
|
|
|
|
| 200 |
lines = [l.strip() for s in final.splitlines() for l in s.split(". ") if l.strip()]
|
| 201 |
final = "\n".join(f"- {l.rstrip('.')}" for l in lines[:20])
|
| 202 |
|
| 203 |
+
elapsed = time.time() - start_time
|
| 204 |
+
return jsonify({"summary": final, "meta": {"length_choice": length, "tone": tone, "time_seconds": round(elapsed, 2)}})
|
| 205 |
|
| 206 |
if __name__ == "__main__":
    # Dev-server entry point; port 7860 is the Hugging Face Spaces default.
    # BUG FIX: the comment promised "debug off in production" but the code
    # passed debug=True, which enables Flask's interactive debugger (remote
    # code execution risk on 0.0.0.0) and the auto-reloader. Keep it off.
    app.run(host="0.0.0.0", port=7860, debug=False)
|