import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForTokenClassification
import pandas as pd
import trafilatura

st.set_page_config(layout="wide", page_title="LinkBERT")

MODEL_ID = "dejanseo/LinkBERT-XL"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# Half precision on GPU, full precision on CPU.
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_ID,
    low_cpu_mem_usage=False,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
model.to(device)
model.eval()
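
# Note: loaded at module level as above, the tokenizer and model are re-created on every
# Streamlit rerun. A commonly used alternative (a sketch, not part of the original app) is
# to wrap the loading in a cached helper so it runs only once per process:
#
#     @st.cache_resource
#     def load_model():
#         tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
#         mdl = AutoModelForTokenClassification.from_pretrained(MODEL_ID).to(device).eval()
#         return tok, mdl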


def tokenize_with_indices(text: str):
    """Tokenize text and return input IDs plus (start, end) character offsets per token."""
    encoded = tokenizer.encode_plus(
        text,
        return_offsets_mapping=True,
        add_special_tokens=True,
        truncation=True,
        max_length=512
    )
    return encoded["input_ids"], encoded["offset_mapping"]
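# Note: encode_plus with return_offsets_mapping=True (fast tokenizers only) returns one
# (start, end) character span per token; the special tokens added at the start and end map
# to (0, 0), which is why the loop in process_text skips the first and last entries.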


def fetch_and_extract_content(url: str):
    """Download a URL and return its main text content, or None if it cannot be retrieved."""
    downloaded = trafilatura.fetch_url(url)
    if downloaded:
        content = trafilatura.extract(downloaded, include_comments=False, include_tables=False)
        return content
    return None
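# Usage: content = fetch_and_extract_content(url) yields the page's main text as a plain
# string, or None when either the download or the extraction step fails.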


def process_text(inputs: str, confidence_threshold: float):
    # Reserve room for the two special tokens added by the tokenizer (512 - 2 content tokens).
    max_chunk_length = 512 - 2
    words = inputs.split()
    chunk_texts = []
    current_chunk, current_length = [], 0
    # Greedily pack whole words into chunks that stay within the token budget.
    for word in words:
        tok_len = len(tokenizer.tokenize(word))
        if tok_len + current_length > max_chunk_length:
            if current_chunk:
                chunk_texts.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = tok_len
        else:
            current_chunk.append(word)
            current_length += tok_len
    if current_chunk:
        chunk_texts.append(" ".join(current_chunk))
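
    # chunk_texts now holds whitespace-joined word chunks, each expected to fit the
    # 510-token budget once re-tokenized (encode_plus truncates as a safeguard).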

    df_data = {"Word": [], "Prediction": [], "Confidence": [], "Start": [], "End": []}
    reconstructed_text = ""
    original_position_offset = 0

    with torch.no_grad():
        for chunk in chunk_texts:
            input_ids, token_offsets = tokenize_with_indices(chunk)

            input_ids_tensor = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)

            outputs = model(input_ids_tensor)
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1).squeeze(0).tolist()
            softmax_scores = F.softmax(logits, dim=-1).squeeze(0).tolist()
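
            # logits has shape (1, seq_len, num_labels); after squeeze, predictions[i] is the
            # argmax label for token i and softmax_scores[i] its per-label probabilities.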

            # Group sub-tokens back into words keyed by the word's character start position.
            word_info = {}
            for idx, (start, end) in enumerate(token_offsets):
                # Skip the special tokens at the start and end of the sequence.
                if idx == 0 or idx == len(token_offsets) - 1:
                    continue

                word_start = start
                while word_start > 0 and chunk[word_start - 1] != ' ':
                    word_start -= 1

                if word_start not in word_info:
                    word_info[word_start] = {"prediction": 0, "confidence": 0.0, "subtokens": []}

                # Mark the whole word as a link candidate only if this sub-token is predicted
                # as label 1 with confidence at or above the chosen threshold.
                conf_pct = softmax_scores[idx][predictions[idx]] * 100.0
                if predictions[idx] == 1 and conf_pct >= confidence_threshold:
                    word_info[word_start]["prediction"] = 1
                    word_info[word_start]["confidence"] = max(word_info[word_start]["confidence"], conf_pct)
                word_info[word_start]["subtokens"].append((start, end, chunk[start:end]))
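
            # word_info now maps each word's start offset in the chunk to its aggregated label,
            # its highest qualifying confidence, and the sub-token spans that compose the word.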

            # Rebuild the chunk text, wrapping predicted anchor words in a highlight span,
            # and record one table row per sub-token.
            last_end = 0
            for word_start in sorted(word_info.keys()):
                word_data = word_info[word_start]
                for subtoken_start, subtoken_end, subtoken_text in word_data["subtokens"]:
                    escaped = subtoken_text.replace("$", "\\$")
                    if last_end < subtoken_start:
                        reconstructed_text += chunk[last_end:subtoken_start]
                    if word_data["prediction"] == 1:
                        reconstructed_text += (
                            f"<span style='background-color: rgba(0, 255, 0); display: inline;'>{escaped}</span>"
                        )
                    else:
                        reconstructed_text += escaped
                    last_end = subtoken_end

                    df_data["Word"].append(escaped)
                    df_data["Prediction"].append(word_data["prediction"])
                    df_data["Confidence"].append(word_info[word_start]["confidence"])
                    df_data["Start"].append(subtoken_start + original_position_offset)
                    df_data["End"].append(subtoken_end + original_position_offset)

            original_position_offset += len(chunk) + 1

            # Append any trailing text, then a separating space so consecutive chunks do not
            # run together (matching the +1 applied to original_position_offset above).
            reconstructed_text += chunk[last_end:].replace("$", "\\$") + " "

    df_tokens = pd.DataFrame(df_data)
    return reconstructed_text, df_tokens
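
# Note: the Start/End columns are character offsets into the whitespace-normalized text that
# process_text reconstructs (words re-joined with single spaces), not into the raw input string,
# and each row of the table corresponds to a single sub-token rather than a whole word.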


st.title("LinkBERT")
st.markdown("""
LinkBERT predicts natural link placement within web content. Paste text or enter a URL to extract
content from, then adjust the confidence threshold: higher values yield fewer predicted links.
""")

confidence_threshold = st.slider("Confidence Threshold", 50, 100, 50)

tab1, tab2 = st.tabs(["Text Input", "URL Input"])

with tab1:
    user_input = st.text_area("Enter text to process:")
    if st.button("Process Text"):
        highlighted_text, df_tokens = process_text(user_input, confidence_threshold)
        st.markdown(highlighted_text, unsafe_allow_html=True)
        st.dataframe(df_tokens)

with tab2:
    url_input = st.text_input("Enter URL to process:")
    if st.button("Fetch and Process"):
        content = fetch_and_extract_content(url_input)
        if content:
            highlighted_text, df_tokens = process_text(content, confidence_threshold)
            st.markdown(highlighted_text, unsafe_allow_html=True)
            st.dataframe(df_tokens)
        else:
            st.error("Could not fetch content from the URL. Please check the URL and try again.")

st.divider()
st.markdown("""
## Applications of LinkBERT
- **Anchor Text Suggestion**
- **Evaluation of Existing Links**
- **Link Placement Guide**
- **Anchor Text Idea Generator**
- **Spam and Inorganic SEO Detection**

## Training and Performance
LinkBERT was fine-tuned on a dataset of organic web content and editorial links.

[Watch the video](https://www.youtube.com/watch?v=A0ZulyVqjZo)

## Engage Our Team
Interested in using this in an automated pipeline for bulk link prediction?

Please [book an appointment](https://dejanmarketing.com/conference/).
""")