# streamlit_app.py
import time
import torch
import streamlit as st
from typing import List, Tuple, Dict, Any, Optional
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
import torch.nn.functional as F
import pandas as pd

# ------------------ CONSTANTS ------------------
MODEL_PATH = "dejanseo/query-fanout"
CACHE_DIR = "/app/cache/huggingface"
MAX_INPUT_LENGTH = 32
MAX_TARGET_LENGTH = 16

# --- BATCHING CONFIGURATION ---
TOTAL_DESIRED_CANDIDATES = 200
GENERATION_BATCH_SIZE = 10

# ------------------ HARDCODED SETTINGS ------------------
GENERATION_CONFIG: Dict[str, Any] = {
    "temperature": 1.10, "top_p": 0.98, "no_repeat_ngram_size": 2,
    "repetition_penalty": 1.10, "seed": 42, "sort_by": "logp/len",
}
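# "sort_by": "logp/len" ranks candidates by length-normalized log-probability,
# computed in avg_logprobs_from_generate below; a worked example is given there.
# Normalizing by length keeps longer fan-outs from being penalized by raw sums.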
# ------------------ MODEL LOADING (CPU/GPU AUTO) ------------------
@st.cache_resource  # load once per process; otherwise every Streamlit rerun reloads the model
def load_model() -> Tuple[MT5Tokenizer, MT5ForConditionalGeneration, torch.device]:
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    tok = MT5Tokenizer.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
    model = MT5ForConditionalGeneration.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
    model.to(device)
    model.eval()
    return tok, model, device
# ------------------ GENERATION HELPERS ------------------
def build_inputs(tok: MT5Tokenizer, url: str, query: str, device: torch.device):
    txt = f"For URL: {url} diversify query: {query}"
    enc = tok(txt, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True)
    return {k: v.to(device) for k, v in enc.items()}

def decode_sequences(tok: MT5Tokenizer, seqs: torch.Tensor) -> List[str]:
    return tok.batch_decode(seqs, skip_special_tokens=True)
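# Sketch of the prompt build_inputs produces, using the app's default UI values
# (truncated to MAX_INPUT_LENGTH = 32 tokens):
#   build_inputs(tok, "dejan.ai", "ai seo agency", device) encodes
#   "For URL: dejan.ai diversify query: ai seo agency"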
def avg_logprobs_from_generate(tok: MT5Tokenizer, gen) -> List[float]:
    # gen.scores is only populated when generate() ran with output_scores=True.
    if getattr(gen, "scores", None) is None:
        return [float("nan")] * gen.sequences.size(0)
    scores, seqs = gen.scores, gen.sequences
    nseq = seqs.size(0)
    eos_id = tok.eos_token_id if tok.eos_token_id is not None else 1
    pad_id = tok.pad_token_id
    sum_logp = torch.zeros(nseq, dtype=torch.float32, device=scores[0].device)
    count = torch.zeros(nseq, dtype=torch.float32, device=scores[0].device)
    finished = torch.zeros(nseq, dtype=torch.bool, device=scores[0].device)
    for t in range(len(scores)):
        # scores[t] holds the logits that produced seqs[:, t + 1]; position 0 of
        # seqs is the decoder start token, which has no score of its own.
        step_logits, step_tok = scores[t], seqs[:, t + 1]
        valid = step_tok.ne(pad_id) & (~finished)
        if valid.any():
            step_logprobs = F.log_softmax(step_logits, dim=-1)
            gather = step_logprobs.gather(1, step_tok.unsqueeze(1)).squeeze(1)
            sum_logp += torch.where(valid, gather, torch.zeros_like(gather))
            count += valid.float()
        finished |= step_tok.eq(eos_id)
    count = torch.where(count.eq(0), torch.ones_like(count), count)  # avoid division by zero
    return [(lp / c).item() for lp, c in zip(sum_logp, count)]
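# "logp/len" worked example (hypothetical numbers): if a candidate's three
# generated tokens had log-probs -0.2, -0.4 and -0.6, its score would be
# (-0.2 - 0.4 - 0.6) / 3 = -0.4. Higher (closer to 0) means the model found the
# expansion more likely; this is the sort key used by the Deep Analysis path.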
# --- Deep Analysis sampling path ---
def sampling_generate(
    tok: MT5Tokenizer,
    model: MT5ForConditionalGeneration,
    device: torch.device,  # kept for signature symmetry; inputs are already on device
    inputs: Dict[str, torch.Tensor],
    top_n: int,
    temperature: float,
    top_p: float,
    no_repeat_ngram_size: int,
    repetition_penalty: float,
    bad_words_ids: Optional[List[List[int]]] = None,
):
    kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        num_return_sequences=top_n,
        return_dict_in_generate=True,
        output_scores=True,
    )
    if int(no_repeat_ngram_size) > 0:
        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
    if float(repetition_penalty) != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    if bad_words_ids:
        kwargs["bad_words_ids"] = bad_words_ids
    with torch.no_grad():  # inference only; avoids building an autograd graph
        gen = model.generate(**inputs, **kwargs)
    return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)
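# How the batch loop below uses bad_words_ids: each batch's outputs are
# tokenized and fed back in, so later batches cannot emit a token sequence that
# was already generated verbatim. Note this only bans exact repeats; near
# duplicates can still appear and are handled by the normalized dedup pass.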
def normalize_text(s: str) -> str:
    # Lowercase and collapse runs of whitespace, e.g. "  AI   SEO " -> "ai seo".
    return " ".join(s.strip().lower().split())
# --- Quick Fan-Out path: diverse beam search (from the original script) ---
def generate_expansions_beam(
    url: str,
    query: str,
    tok: MT5Tokenizer,
    model: MT5ForConditionalGeneration,
    device: torch.device,
    num_return_sequences: int = 10,
) -> List[str]:
    input_text = f"For URL: {url} diversify query: {query}"
    inputs = tok(input_text, max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_TARGET_LENGTH,  # was passed as both max_length and max_new_tokens; one suffices
            num_return_sequences=num_return_sequences,
            num_beams=num_return_sequences * 2,
            num_beam_groups=num_return_sequences,
            diversity_penalty=0.5,
            do_sample=False,  # temperature is ignored under pure beam search, so it was dropped
            early_stopping=True,
            pad_token_id=tok.pad_token_id,
            eos_token_id=tok.eos_token_id,
            forced_eos_token_id=tok.eos_token_id,
        )
    expansions: List[str] = []
    for seq in outputs:
        s = tok.decode(seq, skip_special_tokens=True)
        if s and normalize_text(s) != normalize_text(query):
            expansions.append(s)
    # Preserve generation order while dropping exact duplicates.
    seen = set()
    uniq = []
    for s in expansions:
        if s not in seen:
            seen.add(s)
            uniq.append(s)
    return uniq
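# Diverse beam search sketch: with the defaults above, num_beams = 20 splits
# into num_beam_groups = 10 groups of 2 beams each (num_beams must be divisible
# by num_beam_groups), and diversity_penalty = 0.5 penalizes a group for picking
# tokens that earlier groups already chose at the same step. That is what pushes
# the ten returned sequences apart without any sampling noise.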
# ------------------ STREAMLIT APP ------------------
st.set_page_config(
    page_title="Query Fan-Out by DEJAN AI",
    page_icon="🔎",
    layout="wide",
)
st.logo(
    image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
    link="https://dejan.ai/",
)
tok, model, device = load_model()

st.title("Query Fan-Out Generator")
st.markdown("Enter a URL and a query to generate a diverse set of related queries.")

# Inputs
col1, col2 = st.columns(2)
with col1:
    url = st.text_input("URL", value="dejan.ai", help="Target URL that provides context for the query.")
with col2:
    query = st.text_input("Query", value="ai seo agency", help="The search query you want to expand.")

# Mode + single Run button
mode_high_effort = st.toggle("High Effort", value=False, help="On = Deep Analysis (stochastic sampling, large batch). Off = Quick Fan-Out (beam-based).")
run_btn = st.button("Generate", type="primary")
if run_btn:
    if mode_high_effort:
        # ---- Deep Analysis path (sampling, large batches) ----
        cfg = GENERATION_CONFIG
        with st.spinner("Generating queries..."):
            start_ts = time.time()
            inputs = build_inputs(tok, url, query, device)
            all_texts, all_scores = [], []
            seen_texts_for_bad_words = set()
            num_batches = (TOTAL_DESIRED_CANDIDATES + GENERATION_BATCH_SIZE - 1) // GENERATION_BATCH_SIZE
            progress_bar = st.progress(0)
            for i in range(num_batches):
                # Distinct but reproducible seed per batch: the same inputs
                # yield the same candidate pool across runs.
                current_seed = cfg["seed"] + i
                torch.manual_seed(current_seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed_all(current_seed)
                bad_words_ids = None
                if seen_texts_for_bad_words:
                    # No padding here: with padding=True, pad tokens would be
                    # appended to shorter entries and end up inside the banned
                    # sequences, corrupting the bad_words_ids constraint.
                    bad_words_ids = tok(
                        list(seen_texts_for_bad_words),
                        add_special_tokens=False,
                    )["input_ids"]
                batch_texts, batch_scores = sampling_generate(
                    tok, model, device, inputs,
                    top_n=GENERATION_BATCH_SIZE,
                    temperature=float(cfg["temperature"]),
                    top_p=float(cfg["top_p"]),
                    no_repeat_ngram_size=int(cfg["no_repeat_ngram_size"]),
                    repetition_penalty=float(cfg["repetition_penalty"]),
                    bad_words_ids=bad_words_ids,
                )
                all_texts.extend(batch_texts)
                all_scores.extend(batch_scores)
                for txt in batch_texts:
                    if txt:
                        seen_texts_for_bad_words.add(txt)
                progress_bar.progress((i + 1) / num_batches)
            # Deduplicate and finalize
            final_enriched = []
            final_seen_normalized = set()
            for txt, sc in zip(all_texts, all_scores):
                norm = normalize_text(txt)
                # Compare against the normalized query, not just its lowercase
                # form, so stray whitespace in the input does not defeat the filter.
                if norm and norm not in final_seen_normalized and norm != normalize_text(query):
                    final_seen_normalized.add(norm)
                    final_enriched.append({"logp/len": sc, "text": txt})
            if cfg["sort_by"] == "logp/len":
                final_enriched.sort(key=lambda x: x["logp/len"], reverse=True)
            final_enriched = final_enriched[:TOTAL_DESIRED_CANDIDATES]
        if not final_enriched:
            st.warning("No queries were generated. Try a different input.")
        else:
            output_texts = [item["text"] for item in final_enriched]
            df = pd.DataFrame(output_texts, columns=["Generated Query"])
            df.index = range(1, len(df) + 1)
            st.dataframe(df, use_container_width=True)
            # Surface the timer started above, which was previously unused.
            st.caption(f"Generated {len(df)} queries in {time.time() - start_ts:.1f}s.")
    else:
        # ---- Quick Fan-Out path (beam-based, small and simple) ----
        with st.spinner("Generating quick fan-out..."):
            start_time = time.time()
            expansions = generate_expansions_beam(url, query, tok, model, device, num_return_sequences=10)
        if expansions:
            df = pd.DataFrame(expansions, columns=["Generated Query"])
            df.index = range(1, len(df) + 1)
            st.dataframe(df, use_container_width=True)
            # Surface the timer started above, which was previously unused.
            st.caption(f"Generated {len(df)} queries in {time.time() - start_time:.1f}s.")
        else:
            st.warning("No valid fan-outs generated. Try a different query.")