# linkbert/src/streamlit_app.py
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForTokenClassification
import pandas as pd
import trafilatura

# Streamlit config
st.set_page_config(layout="wide", page_title="LinkBERT")

# Model setup (load fully to avoid meta tensors)
MODEL_ID = "dejanseo/LinkBERT-XL"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# Force materialized tensors on CPU, then move to the target device (avoids meta tensors)
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_ID,
    low_cpu_mem_usage=False,  # important: materialize weights
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
model.to(device)
model.eval()
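
# Optional sketch, not wired into the app above: in a long-running Streamlit
# session the module-level load re-runs on every script rerun. Wrapping it in a
# cached factory (st.cache_resource, available in Streamlit >= 1.18) keeps a
# single tokenizer/model pair in memory per server process. Shown only as an
# illustration of that pattern; the name load_linkbert is not part of the
# original app.
@st.cache_resource
def load_linkbert(model_id: str = MODEL_ID):
    # Load once per process; subsequent reruns reuse the cached objects.
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    mdl = AutoModelForTokenClassification.from_pretrained(
        model_id,
        low_cpu_mem_usage=False,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    return tok, mdl.to(device).eval()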

# Functions
def tokenize_with_indices(text: str):
    encoded = tokenizer.encode_plus(
        text,
        return_offsets_mapping=True,
        add_special_tokens=True,
        truncation=True,
        max_length=512,
    )
    return encoded["input_ids"], encoded["offset_mapping"]
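
# Illustrative only (exact ids and sub-token splits depend on the LinkBERT-XL
# tokenizer):
#   ids, offsets = tokenize_with_indices("link building tips")
#   # offsets -> [(0, 0), (0, 4), (5, 13), (14, 18), (0, 0)]
# Special tokens map to (0, 0); each real token maps to its character span in
# the input string, which process_text uses below to group sub-tokens back
# into whitespace-delimited words.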

def fetch_and_extract_content(url: str):
    downloaded = trafilatura.fetch_url(url)
    if downloaded:
        content = trafilatura.extract(downloaded, include_comments=False, include_tables=False)
        return content
    return None

def process_text(inputs: str, confidence_threshold: float):
    max_chunk_length = 512 - 2  # leave room for specials
    words = inputs.split()
    chunk_texts = []
    current_chunk, current_length = [], 0
    for word in words:
        tok_len = len(tokenizer.tokenize(word))
        if tok_len + current_length > max_chunk_length:
            if current_chunk:
                chunk_texts.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = tok_len
        else:
            current_chunk.append(word)
            current_length += tok_len
    if current_chunk:
        chunk_texts.append(" ".join(current_chunk))

    df_data = {"Word": [], "Prediction": [], "Confidence": [], "Start": [], "End": []}
    reconstructed_text = ""
    original_position_offset = 0

    with torch.no_grad():
        for chunk in chunk_texts:
            input_ids, token_offsets = tokenize_with_indices(chunk)

            # Build tensors on the correct device; no meta usage
            input_ids_tensor = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
            outputs = model(input_ids_tensor)
            logits = outputs.logits  # [1, seq_len, num_labels]
            predictions = torch.argmax(logits, dim=-1).squeeze(0).tolist()
            softmax_scores = F.softmax(logits, dim=-1).squeeze(0).tolist()

            # Group sub-tokens into words keyed by the word's starting character offset.
            word_info = {}
            for idx, (start, end) in enumerate(token_offsets):
                if idx == 0 or idx == len(token_offsets) - 1:
                    continue  # skip special tokens
                word_start = start
                while word_start > 0 and chunk[word_start - 1] != ' ':
                    word_start -= 1
                if word_start not in word_info:
                    word_info[word_start] = {"prediction": 0, "confidence": 0.0, "subtokens": []}
                conf_pct = softmax_scores[idx][predictions[idx]] * 100.0
                if predictions[idx] == 1 and conf_pct >= confidence_threshold:
                    word_info[word_start]["prediction"] = 1
                    word_info[word_start]["confidence"] = max(word_info[word_start]["confidence"], conf_pct)
                word_info[word_start]["subtokens"].append((start, end, chunk[start:end]))

            # Rebuild the chunk text, wrapping predicted link words in a highlight span.
            last_end = 0
            for word_start in sorted(word_info.keys()):
                word_data = word_info[word_start]
                for subtoken_start, subtoken_end, subtoken_text in word_data["subtokens"]:
                    escaped = subtoken_text.replace("$", "\\$")
                    if last_end < subtoken_start:
                        reconstructed_text += chunk[last_end:subtoken_start]
                    if word_data["prediction"] == 1:
                        reconstructed_text += (
                            f"<span style='background-color: rgba(0, 255, 0); display: inline;'>{escaped}</span>"
                        )
                    else:
                        reconstructed_text += escaped
                    last_end = subtoken_end
                    df_data["Word"].append(escaped)
                    df_data["Prediction"].append(word_data["prediction"])
                    df_data["Confidence"].append(word_data["confidence"])
                    df_data["Start"].append(subtoken_start + original_position_offset)
                    df_data["End"].append(subtoken_end + original_position_offset)

            # Append the tail of the chunk plus a separating space so chunk
            # boundaries match the +1 added to original_position_offset.
            reconstructed_text += chunk[last_end:].replace("$", "\\$") + " "
            original_position_offset += len(chunk) + 1

    df_tokens = pd.DataFrame(df_data)
    return reconstructed_text, df_tokens
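
# Illustrative usage outside the Streamlit widgets (not executed here; sample
# text and threshold are arbitrary):
#   html, df = process_text("Our guide to internal linking covers anchor text basics.", 60)
#   # `html` is the input with predicted anchor words wrapped in highlighted
#   # <span> tags; `df` has one row per sub-token with Prediction, Confidence,
#   # and character Start/End offsets.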

# UI
st.title("LinkBERT")
st.markdown("""
LinkBERT predicts natural link placement within web content. Enter text directly, or provide a URL to extract content from. Increase the confidence threshold to reduce the number of predicted links.
""")

confidence_threshold = st.slider("Confidence Threshold", 50, 100, 50)

tab1, tab2 = st.tabs(["Text Input", "URL Input"])

with tab1:
    user_input = st.text_area("Enter text to process:")
    if st.button("Process Text"):
        highlighted_text, df_tokens = process_text(user_input, confidence_threshold)
        st.markdown(highlighted_text, unsafe_allow_html=True)
        st.dataframe(df_tokens)

with tab2:
    url_input = st.text_input("Enter URL to process:")
    if st.button("Fetch and Process"):
        content = fetch_and_extract_content(url_input)
        if content:
            highlighted_text, df_tokens = process_text(content, confidence_threshold)
            st.markdown(highlighted_text, unsafe_allow_html=True)
            st.dataframe(df_tokens)
        else:
            st.error("Could not fetch content from the URL. Please check the URL and try again.")

st.divider()

st.markdown("""
## Applications of LinkBERT

- **Anchor Text Suggestion**
- **Evaluation of Existing Links**
- **Link Placement Guide**
- **Anchor Text Idea Generator**
- **Spam and Inorganic SEO Detection**

## Training and Performance

LinkBERT was fine-tuned on a dataset of organic web content and editorial links.

[Watch the video](https://www.youtube.com/watch?v=A0ZulyVqjZo)

## Engage Our Team

Interested in using this in an automated pipeline for bulk link prediction? Please [book an appointment](https://dejanmarketing.com/conference/).
""")