import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForTokenClassification
import pandas as pd
import trafilatura

# Streamlit config
st.set_page_config(layout="wide", page_title="LinkBERT")

# Model setup (load fully to avoid meta tensors)
MODEL_ID = "dejanseo/LinkBERT-XL"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# Force materialized tensors on CPU, then move to the target device (avoids meta tensors)
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_ID,
    low_cpu_mem_usage=False,  # important: materialize weights
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
model.to(device)
model.eval()


# Functions
def tokenize_with_indices(text: str):
    """Tokenize text and return input IDs plus character offsets for each token."""
    encoded = tokenizer.encode_plus(
        text,
        return_offsets_mapping=True,
        add_special_tokens=True,
        truncation=True,
        max_length=512,
    )
    return encoded["input_ids"], encoded["offset_mapping"]


def fetch_and_extract_content(url: str):
    """Download a URL and extract its main text content with trafilatura."""
    downloaded = trafilatura.fetch_url(url)
    if downloaded:
        content = trafilatura.extract(downloaded, include_comments=False, include_tables=False)
        return content
    return None


def process_text(inputs: str, confidence_threshold: float):
    """Classify tokens and return highlighted text plus a per-token DataFrame."""
    # Split the input into whitespace-delimited chunks that fit the model's context window.
    max_chunk_length = 512 - 2  # leave room for special tokens
    words = inputs.split()
    chunk_texts = []
    current_chunk, current_length = [], 0
    for word in words:
        tok_len = len(tokenizer.tokenize(word))
        if tok_len + current_length > max_chunk_length:
            if current_chunk:
                chunk_texts.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = tok_len
        else:
            current_chunk.append(word)
            current_length += tok_len
    if current_chunk:
        chunk_texts.append(" ".join(current_chunk))

    df_data = {"Word": [], "Prediction": [], "Confidence": [], "Start": [], "End": []}
    reconstructed_text = ""
    original_position_offset = 0

    with torch.no_grad():
        for chunk in chunk_texts:
            input_ids, token_offsets = tokenize_with_indices(chunk)
            # Build tensors on the correct device; no meta usage
            input_ids_tensor = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)

            outputs = model(input_ids_tensor)
            logits = outputs.logits  # [1, seq_len, num_labels]
            predictions = torch.argmax(logits, dim=-1).squeeze(0).tolist()
            softmax_scores = F.softmax(logits, dim=-1).squeeze(0).tolist()

            # Group sub-tokens under the character position where their word starts.
            word_info = {}
            for idx, (start, end) in enumerate(token_offsets):
                if idx == 0 or idx == len(token_offsets) - 1:
                    continue  # skip special tokens
                word_start = start
                while word_start > 0 and chunk[word_start - 1] != " ":
                    word_start -= 1
                if word_start not in word_info:
                    word_info[word_start] = {"prediction": 0, "confidence": 0.0, "subtokens": []}
                conf_pct = softmax_scores[idx][predictions[idx]] * 100.0
                if predictions[idx] == 1 and conf_pct >= confidence_threshold:
                    word_info[word_start]["prediction"] = 1
                word_info[word_start]["confidence"] = max(word_info[word_start]["confidence"], conf_pct)
                word_info[word_start]["subtokens"].append((start, end, chunk[start:end]))

            # Rebuild the chunk, wrapping predicted link words in an HTML highlight.
            last_end = 0
            for word_start in sorted(word_info.keys()):
                word_data = word_info[word_start]
                for subtoken_start, subtoken_end, subtoken_text in word_data["subtokens"]:
                    escaped = subtoken_text.replace("$", "\\$")
                    if last_end < subtoken_start:
                        reconstructed_text += chunk[last_end:subtoken_start]
                    if word_data["prediction"] == 1:
                        # The exact highlight styling is an assumption (any inline style works);
                        # it is rendered later via st.markdown(..., unsafe_allow_html=True).
                        reconstructed_text += (
                            f"<span style='background-color: rgba(0, 255, 0, 0.5);'>{escaped}</span>"
                        )
                    else:
                        reconstructed_text += escaped
                    last_end = subtoken_end
                    df_data["Word"].append(escaped)
                    df_data["Prediction"].append(word_data["prediction"])
df_data["Confidence"].append(word_info[word_start]["confidence"]) df_data["Start"].append(subtoken_start + original_position_offset) df_data["End"].append(subtoken_end + original_position_offset) original_position_offset += len(chunk) + 1 reconstructed_text += chunk[last_end:].replace("$", "\\$") df_tokens = pd.DataFrame(df_data) return reconstructed_text, df_tokens # UI st.title("LinkBERT") st.markdown(""" LinkBERT predicts natural link placement within web content. Enter text or a URL for extraction. Increase the threshold to reduce link predictions. """) confidence_threshold = st.slider("Confidence Threshold", 50, 100, 50) tab1, tab2 = st.tabs(["Text Input", "URL Input"]) with tab1: user_input = st.text_area("Enter text to process:") if st.button("Process Text"): highlighted_text, df_tokens = process_text(user_input, confidence_threshold) st.markdown(highlighted_text, unsafe_allow_html=True) st.dataframe(df_tokens) with tab2: url_input = st.text_input("Enter URL to process:") if st.button("Fetch and Process"): content = fetch_and_extract_content(url_input) if content: highlighted_text, df_tokens = process_text(content, confidence_threshold) st.markdown(highlighted_text, unsafe_allow_html=True) st.dataframe(df_tokens) else: st.error("Could not fetch content from the URL. Please check the URL and try again.") st.divider() st.markdown(""" ## Applications of LinkBERT - **Anchor Text Suggestion** - **Evaluation of Existing Links** - **Link Placement Guide** - **Anchor Text Idea Generator** - **Spam and Inorganic SEO Detection** ## Training and Performance LinkBERT was fine-tuned on a dataset of organic web content and editorial links. [Watch the video](https://www.youtube.com/watch?v=A0ZulyVqjZo) # Engage Our Team Interested in using this in an automated pipeline for bulk link prediction? Please [book an appointment](https://dejanmarketing.com/conference/). """)