# query-fanout / src/streamlit_app.py
# Last commit 6fe4e82 (verified) by dejanseo: "Update src/streamlit_app.py"
# streamlit_app.py
import time
import torch
import streamlit as st
from typing import List, Tuple, Dict, Any
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
import torch.nn.functional as F
import pandas as pd
# ------------------ CONSTANTS ------------------
# Hugging Face model repo id and the local directory its files are cached in.
MODEL_PATH = "dejanseo/query-fanout"
CACHE_DIR = "/app/cache/huggingface"

# Token budgets: prompts are truncated to MAX_INPUT_LENGTH tokens, and the
# length of each generated query is bounded by MAX_TARGET_LENGTH.
MAX_INPUT_LENGTH = 32
MAX_TARGET_LENGTH = 16

# --- BATCHING CONFIGURATION ---
# Deep Analysis draws GENERATION_BATCH_SIZE samples per generate() call, in
# ceil(TOTAL_DESIRED_CANDIDATES / GENERATION_BATCH_SIZE) batches.
TOTAL_DESIRED_CANDIDATES = 200
GENERATION_BATCH_SIZE = 10

# ------------------ HARDCODED SETTINGS ------------------
# Sampling parameters for the Deep Analysis path. "seed" is the base seed
# (offset by the batch index) and "sort_by" names the ranking key.
GENERATION_CONFIG: Dict[str, Any] = {
    "temperature": 1.10,
    "top_p": 0.98,
    "no_repeat_ngram_size": 2,
    "repetition_penalty": 1.10,
    "seed": 42,
    "sort_by": "logp/len",
}
# ------------------ MODEL LOADING (CPU/GPU AUTO) ------------------
@st.cache_resource
def load_model() -> Tuple[MT5Tokenizer, MT5ForConditionalGeneration, torch.device]:
    """Load the mT5 tokenizer and model and place the model on the best device.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the loaded objects
    across reruns instead of reloading them on every interaction.

    Returns:
        ``(tokenizer, model, device)``; the model is on ``device`` (CUDA when
        available, otherwise CPU) and switched to eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = MT5Tokenizer.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
    model = MT5ForConditionalGeneration.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
    model = model.to(device)
    # Inference only: disable dropout and other training-mode behaviour.
    model.eval()
    return tok, model, device
# ------------------ GENERATION HELPERS ------------------
def build_inputs(tok: MT5Tokenizer, url: str, query: str, device: torch.device):
    """Tokenize the model prompt for ``(url, query)`` and move it to ``device``.

    The prompt is truncated to ``MAX_INPUT_LENGTH`` tokens. The result is a
    dict of tensors keyed by the tokenizer's output names, suitable for
    ``model.generate(**inputs, ...)``.
    """
    prompt = f"For URL: {url} diversify query: {query}"
    encoded = tok(prompt, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True)
    on_device = {}
    for name, tensor in encoded.items():
        on_device[name] = tensor.to(device)
    return on_device
def decode_sequences(tok: MT5Tokenizer, seqs: torch.Tensor) -> List[str]:
    """Decode a batch of token-id sequences to strings, omitting special tokens."""
    texts = tok.batch_decode(seqs, skip_special_tokens=True)
    return texts
def avg_logprobs_from_generate(tok: "MT5Tokenizer", gen) -> List[float]:
    """Return the mean per-token log-probability of each generated sequence.

    ``gen`` is the output of ``model.generate(..., return_dict_in_generate=True,
    output_scores=True)``. ``gen.scores`` holds one ``(num_sequences, vocab)``
    logits tensor per generated step. ``gen.sequences`` holds the token ids;
    its first column is assumed to be the decoder start token, hence the
    ``t + 1`` offset below.

    For every step, the log-softmax of that step's scores is taken at the
    token actually emitted. Tokens equal to the pad id, and every token after
    a sequence's first EOS, are excluded. The EOS token itself is counted.
    With transformers' ``generate`` the reported scores are the *processed*
    logits (after temperature / top-p / penalties), so the result reflects
    the sampling distribution rather than the raw model distribution.

    Args:
        tok: tokenizer providing ``eos_token_id`` and ``pad_token_id``.
            A missing pad id disables pad masking. A missing eos id falls
            back to 1.
        gen: generate output exposing ``sequences`` and ``scores``.

    Returns:
        One Python float per sequence. Every entry is NaN when no scores were
        returned (attribute missing, ``None`` or empty). A sequence with no
        counted tokens scores 0.0.
    """
    seqs = gen.sequences
    nseq = seqs.size(0)
    scores = getattr(gen, "scores", None)
    # With output_scores=False, `scores` exists but is None, so hasattr()
    # alone is not enough to know whether scores are usable.
    if not scores:
        return [float("nan")] * nseq

    # Compare against None explicitly: a legitimate eos id of 0 must not be
    # replaced by the fallback (which `x or 1` would do).
    eos_id = tok.eos_token_id if tok.eos_token_id is not None else 1
    pad_id = tok.pad_token_id

    device = scores[0].device
    sum_logp = torch.zeros(nseq, dtype=torch.float32, device=device)
    count = torch.zeros(nseq, dtype=torch.float32, device=device)
    finished = torch.zeros(nseq, dtype=torch.bool, device=device)

    for t, step_logits in enumerate(scores):
        step_tok = seqs[:, t + 1]
        # Count a token only while its sequence is unfinished and it is not padding.
        valid = ~finished if pad_id is None else step_tok.ne(pad_id) & ~finished
        if valid.any():
            # Upcast so half-precision logits do not lose accuracy in log_softmax.
            step_logprobs = F.log_softmax(step_logits.float(), dim=-1)
            picked = step_logprobs.gather(1, step_tok.unsqueeze(1)).squeeze(1)
            sum_logp += torch.where(valid, picked, torch.zeros_like(picked))
            count += valid.float()
        # Mark finished after scoring, so the EOS token itself is included.
        finished |= step_tok.eq(eos_id)

    # Avoid 0/0 for sequences that contributed no tokens.
    count = count.clamp(min=1.0)
    # One host transfer instead of one .item() sync per sequence.
    return (sum_logp / count).tolist()
# --- Sampling-based generation (Deep Analysis) ---
def _clean_bad_words_ids(bad_words_ids, pad_id):
    """Return ``bad_words_ids`` as lists with trailing pad ids and empty entries removed."""
    # A padded entry such as [a, b, pad] is only triggered after the model has
    # emitted a, b, pad, which real output never does, so padding silently
    # disables that ban. transformers also rejects empty sequences.
    cleaned = []
    for seq in bad_words_ids or []:
        ids = list(seq)
        while pad_id is not None and ids and ids[-1] == pad_id:
            ids.pop()
        if ids:
            cleaned.append(ids)
    return cleaned


def sampling_generate(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty, bad_words_ids: List[List[int]] = None):
    """Sample ``top_n`` query rewrites for pre-tokenized ``inputs`` and score them.

    Args:
        tok: tokenizer used for decoding, scoring and its ``pad_token_id``.
        model: seq2seq model whose ``generate`` is called with sampling on.
        device: not used by this function (kept for interface compatibility).
        inputs: encoder inputs, e.g. the dict returned by ``build_inputs``.
        top_n: number of sequences returned by this call.
        temperature: sampling temperature passed to ``generate``.
        top_p: nucleus-sampling threshold passed to ``generate``.
        no_repeat_ngram_size: forwarded only when > 0.
        repetition_penalty: forwarded only when != 1.0.
        bad_words_ids: optional token-id sequences the model must not
            produce. Trailing pad ids and empty sequences are removed before
            use; nothing is forwarded if none remain.

    Returns:
        ``(texts, scores)``: the decoded strings and their mean per-token
        log-probabilities, index-aligned.
    """
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        num_return_sequences=top_n,
        return_dict_in_generate=True,
        # Needed by avg_logprobs_from_generate to score each sample.
        output_scores=True,
    )
    if int(no_repeat_ngram_size) > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if float(repetition_penalty) != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    cleaned_bad_words = _clean_bad_words_ids(bad_words_ids, tok.pad_token_id)
    if cleaned_bad_words:
        kwargs["bad_words_ids"] = cleaned_bad_words
    gen = model.generate(**inputs, **kwargs)
    return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)
def normalize_text(s: str) -> str:
    """Canonicalize ``s`` for comparison: lower-cased, trimmed, single-spaced."""
    words = s.lower().split()
    return " ".join(words)
# --- Beam-based quick generation (Quick Fan-Out) ---
def generate_expansions_beam(url: str, query: str, tok: MT5Tokenizer, model: MT5ForConditionalGeneration, device: torch.device, num_return_sequences: int = 10) -> List[str]:
    """Generate up to ``num_return_sequences`` diverse rewrites of ``query`` for ``url``.

    Uses grouped (diverse) beam search with two beams per group, so no sampling
    is involved. Outputs are filtered in generation order:
    - outputs that are empty after ``normalize_text`` are dropped;
    - outputs equal to the normalized query are dropped;
    - duplicates under that same normalization are dropped.

    Returns:
        The surviving decoded strings, or ``[]`` when ``num_return_sequences``
        is less than 1.
    """
    if num_return_sequences < 1:
        return []
    inputs = build_inputs(tok, url, query, device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            num_return_sequences=num_return_sequences,
            # Two beams per group; the diversity penalty pushes groups apart.
            num_beams=num_return_sequences * 2,
            num_beam_groups=num_return_sequences,
            diversity_penalty=0.5,
            # Deterministic beam search; temperature is omitted because it
            # only applies when sampling.
            do_sample=False,
            early_stopping=True,
            pad_token_id=tok.pad_token_id,
            eos_token_id=tok.eos_token_id,
            forced_eos_token_id=tok.eos_token_id,
            # Sole length bound: a max_length alongside it would be
            # overridden by max_new_tokens.
            max_new_tokens=MAX_TARGET_LENGTH,
        )

    query_norm = normalize_text(query)
    seen = set()
    expansions: List[str] = []
    for text in tok.batch_decode(outputs, skip_special_tokens=True):
        norm = normalize_text(text)
        if not norm or norm == query_norm or norm in seen:
            continue
        seen.add(norm)
        expansions.append(text)
    return expansions
# ------------------ STREAMLIT APP ------------------
# Page-level configuration and branding.
st.set_page_config(page_title="Query Fan-Out by DEJAN AI", page_icon="🔎", layout="wide")
st.logo(image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png", link="https://dejan.ai/")

# load_model is wrapped in st.cache_resource, so reruns reuse the loaded model.
tok, model, device = load_model()

st.title("Query Fanout Generator")
st.markdown("Enter a URL and a query to generate a diverse set of related queries.")

# Inputs: URL on the left, query on the right.
url_col, query_col = st.columns(2)
url = url_col.text_input("URL", value="dejan.ai", help="Target URL that provides context for the query.")
query = query_col.text_input("Query", value="ai seo agency", help="The search query you want to expand.")

# Mode selector and the single Run button.
mode_high_effort = st.toggle(
    "High Effort",
    value=False,
    help="On = Deep Analysis (stochastic sampling, large batch). Off = Quick Fan-Out (beam-based).",
)
run_btn = st.button("Generate", type="primary")
def _render_queries(texts: List[str], empty_message: str) -> None:
    """Show ``texts`` as a 1-indexed "Generated Query" table, or warn if empty."""
    if not texts:
        st.warning(empty_message)
        return
    df = pd.DataFrame(texts, columns=["Generated Query"])
    df.index = range(1, len(df) + 1)
    st.dataframe(df, use_container_width=True)


if run_btn:
    # Normalized once. It is used for the empty-input guard and to drop
    # echoes of the query, using the same normalization as the beam path.
    query_norm = normalize_text(query)
    if not query_norm:
        st.warning("Please enter a query.")
    elif mode_high_effort:
        # ---- Deep Analysis path (sampling, large batches) ----
        cfg = GENERATION_CONFIG
        with st.spinner("Generating queries..."):
            inputs = build_inputs(tok, url, query, device)
            all_texts, all_scores = [], []
            # Distinct non-empty texts generated so far, and their token ids.
            # The ids are passed as bad_words_ids so later batches are steered
            # away from repeating earlier outputs verbatim.
            seen_texts_for_bad_words = set()
            bad_words_ids: List[List[int]] = []
            num_batches = (TOTAL_DESIRED_CANDIDATES + GENERATION_BATCH_SIZE - 1) // GENERATION_BATCH_SIZE
            progress_bar = st.progress(0)
            for i in range(num_batches):
                # Distinct but reproducible seed per batch.
                current_seed = cfg["seed"] + i
                torch.manual_seed(current_seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed_all(current_seed)
                batch_texts, batch_scores = sampling_generate(
                    tok, model, device, inputs,
                    top_n=GENERATION_BATCH_SIZE,
                    temperature=float(cfg["temperature"]),
                    top_p=float(cfg["top_p"]),
                    no_repeat_ngram_size=int(cfg["no_repeat_ngram_size"]),
                    repetition_penalty=float(cfg["repetition_penalty"]),
                    bad_words_ids=bad_words_ids or None,
                )
                all_texts.extend(batch_texts)
                all_scores.extend(batch_scores)
                # Tokenize only texts that are not already banned, instead of
                # re-tokenizing the whole history every batch. No padding: pad
                # ids appended to a bad-word sequence turn it into a sequence
                # the model never emits, which silently disables that ban.
                new_texts = [
                    txt for txt in dict.fromkeys(batch_texts)
                    if txt and txt not in seen_texts_for_bad_words
                ]
                if new_texts:
                    seen_texts_for_bad_words.update(new_texts)
                    new_ids = tok(new_texts, add_special_tokens=False, truncation=True)["input_ids"]
                    # Empty bad-word sequences are rejected by transformers.
                    bad_words_ids.extend(ids for ids in new_ids if ids)
                progress_bar.progress((i + 1) / num_batches)

        # Deduplicate (by normalized text), drop query echoes, rank, and cap.
        final_enriched = []
        final_seen_normalized = set()
        for txt, sc in zip(all_texts, all_scores):
            norm = normalize_text(txt)
            if norm and norm != query_norm and norm not in final_seen_normalized:
                final_seen_normalized.add(norm)
                final_enriched.append({"logp/len": sc, "text": txt})
        if cfg["sort_by"] == "logp/len":
            final_enriched.sort(key=lambda item: item["logp/len"], reverse=True)
        final_enriched = final_enriched[:TOTAL_DESIRED_CANDIDATES]
        _render_queries(
            [item["text"] for item in final_enriched],
            "No queries were generated. Try a different input.",
        )
    else:
        # ---- Quick Fan-Out path (beam-based, small and simple) ----
        with st.spinner("Generating quick fan-out..."):
            expansions = generate_expansions_beam(url, query, tok, model, device, num_return_sequences=10)
        _render_queries(expansions, "No valid fan-outs generated. Try a different query.")