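# LinkBERT Streamlit demo: highlights spans the model predicts as natural
# link anchor text. Launch with: streamlit run <this_file>.py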
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForTokenClassification
import pandas as pd
import trafilatura

# Streamlit config
st.set_page_config(layout="wide", page_title="LinkBERT")

# Model setup (load fully to avoid meta tensors)
MODEL_ID = "dejanseo/LinkBERT-XL"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@st.cache_resource  # cache so Streamlit doesn't reload the weights on every rerun
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    # Force materialized tensors on CPU, then move: avoids meta tensors
    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_ID,
        low_cpu_mem_usage=False,   # important: materialize weights
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    model.to(device)
    model.eval()
    return tokenizer, model

tokenizer, model = load_model()

# Functions
def tokenize_with_indices(text: str):
    encoded = tokenizer.encode_plus(
        text,
        return_offsets_mapping=True,
        add_special_tokens=True,
        truncation=True,
        max_length=512
    )
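    # offset_mapping gives character (start, end) per token; with a fast
    # tokenizer, special tokens map to (0, 0).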
    return encoded["input_ids"], encoded["offset_mapping"]

def fetch_and_extract_content(url: str):
    # fetch_url returns None on network errors; extract() can likewise return
    # None when the page has no extractable main content.
    downloaded = trafilatura.fetch_url(url)
    if downloaded:
        return trafilatura.extract(downloaded, include_comments=False, include_tables=False)
    return None

def process_text(inputs: str, confidence_threshold: float):
    """Split input into chunks that fit the 512-token window, classify each
    token, and return (highlighted HTML, per-subtoken DataFrame)."""
    max_chunk_length = 512 - 2  # leave room for the [CLS]/[SEP] special tokens
    words = inputs.split()
    chunk_texts = []
    current_chunk, current_length = [], 0
    for word in words:
        tok_len = len(tokenizer.tokenize(word))
        if tok_len + current_length > max_chunk_length:
            if current_chunk:
                chunk_texts.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = tok_len
        else:
            current_chunk.append(word)
            current_length += tok_len
    if current_chunk:
        chunk_texts.append(" ".join(current_chunk))
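    # Note: per-word token counts from tokenizer.tokenize(word) can differ
    # slightly from the counts encode_plus produces on the joined chunk, so
    # truncation=True in tokenize_with_indices is the real 512-token safety net.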

    df_data = {"Word": [], "Prediction": [], "Confidence": [], "Start": [], "End": []}
    reconstructed_text = ""
    original_position_offset = 0

    with torch.no_grad():
        for chunk in chunk_texts:
            input_ids, token_offsets = tokenize_with_indices(chunk)
            # Build tensors on correct device; no meta usage
            input_ids_tensor = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)

            outputs = model(input_ids_tensor)
            logits = outputs.logits  # [1, seq_len, num_labels]
            predictions = torch.argmax(logits, dim=-1).squeeze(0).tolist()
            softmax_scores = F.softmax(logits, dim=-1).squeeze(0).tolist()
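            # Argmax picks each token's label; its softmax probability serves
            # as the confidence score (label 1 = link token, as used below).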

            word_info = {}
            for idx, (start, end) in enumerate(token_offsets):
                if idx == 0 or idx == len(token_offsets) - 1:
                    continue  # skip special tokens

                # Walk back to the preceding space so each subtoken is
                # grouped under the character offset of its parent word.
                word_start = start
                while word_start > 0 and chunk[word_start - 1] != ' ':
                    word_start -= 1

                if word_start not in word_info:
                    word_info[word_start] = {"prediction": 0, "confidence": 0.0, "subtokens": []}

                conf_pct = softmax_scores[idx][predictions[idx]] * 100.0
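                # A word counts as a link if any of its subtokens is predicted
                # as label 1 above the threshold; the word keeps the max
                # confidence across its subtokens.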
                if predictions[idx] == 1 and conf_pct >= confidence_threshold:
                    word_info[word_start]["prediction"] = 1
                word_info[word_start]["confidence"] = max(word_info[word_start]["confidence"], conf_pct)
                word_info[word_start]["subtokens"].append((start, end, chunk[start:end]))

            last_end = 0
            for word_start in sorted(word_info.keys()):
                word_data = word_info[word_start]
                for subtoken_start, subtoken_end, subtoken_text in word_data["subtokens"]:
                    escaped = subtoken_text.replace("$", "\\$")
                    if last_end < subtoken_start:
                        reconstructed_text += chunk[last_end:subtoken_start]
                    if word_data["prediction"] == 1:
                        reconstructed_text += (
                            f"<span style='background-color: rgba(0, 255, 0); display: inline;'>{escaped}</span>"
                        )
                    else:
                        reconstructed_text += escaped
                    last_end = subtoken_end

                    df_data["Word"].append(escaped)
                    df_data["Prediction"].append(word_data["prediction"])
                    df_data["Confidence"].append(word_info[word_start]["confidence"])
                    df_data["Start"].append(subtoken_start + original_position_offset)
                    df_data["End"].append(subtoken_end + original_position_offset)

            # Advance once per chunk (not per word); +1 accounts for the
            # space separating this chunk from the next.
            original_position_offset += len(chunk) + 1

            reconstructed_text += chunk[last_end:].replace("$", "\\$") + " "  # keep chunks separated

    df_tokens = pd.DataFrame(df_data)
    return reconstructed_text, df_tokens

# UI
st.title("LinkBERT")
st.markdown("""
LinkBERT predicts natural link placement within web content. Paste text directly or enter a URL to extract its main content. Raise the confidence threshold for fewer, higher-confidence link predictions.
""")

confidence_threshold = st.slider("Confidence Threshold (%)", 50, 100, 50)

tab1, tab2 = st.tabs(["Text Input", "URL Input"])

with tab1:
    user_input = st.text_area("Enter text to process:")
    if st.button("Process Text"):
        highlighted_text, df_tokens = process_text(user_input, confidence_threshold)
        st.markdown(highlighted_text, unsafe_allow_html=True)
        st.dataframe(df_tokens)

with tab2:
    url_input = st.text_input("Enter URL to process:")
    if st.button("Fetch and Process"):
        content = fetch_and_extract_content(url_input)
        if content:
            highlighted_text, df_tokens = process_text(content, confidence_threshold)
            st.markdown(highlighted_text, unsafe_allow_html=True)
            st.dataframe(df_tokens)
        else:
            st.error("Could not fetch content from the URL. Please check the URL and try again.")

st.divider()
st.markdown("""
## Applications of LinkBERT
- **Anchor Text Suggestion**
- **Evaluation of Existing Links**
- **Link Placement Guide**
- **Anchor Text Idea Generator**
- **Spam and Inorganic SEO Detection**

## Training and Performance
LinkBERT was fine-tuned on a dataset of organic web content and editorial links.

[Watch the video](https://www.youtube.com/watch?v=A0ZulyVqjZo)

## Engage Our Team
Interested in using this in an automated pipeline for bulk link prediction?

Please [book an appointment](https://dejanmarketing.com/conference/).
""")