|
import streamlit as st |
|
import torch |
|
from transformers import BertForTokenClassification, BertTokenizerFast |
|
|
|
def load_model(model_name='linkbert.pth'): |
|
model_path = model_name |
|
model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=2) |
|
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) |
|
model.eval() |
|
return model |
|
|
|
def predict_and_annotate(model, tokenizer, text): |
|
|
|
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, return_offsets_mapping=True) |
|
input_ids, attention_mask, offset_mapping = inputs["input_ids"], inputs["attention_mask"], inputs["offset_mapping"] |
|
|
|
with torch.no_grad(): |
|
outputs = model(input_ids=input_ids, attention_mask=attention_mask) |
|
predictions = torch.argmax(outputs.logits, dim=-1) |
|
|
|
tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze().tolist()) |
|
predictions = predictions.squeeze().tolist() |
|
offset_mapping = offset_mapping.squeeze().tolist() |
|
|
|
annotated_text = "" |
|
previous_end = 0 |
|
for offset, prediction in zip(offset_mapping, predictions): |
|
start, end = offset |
|
if start == end: |
|
continue |
|
if prediction == 1: |
|
if start > previous_end: |
|
annotated_text += text[previous_end:start] |
|
annotated_text += f"<u>{text[start:end]}</u>" |
|
else: |
|
if start > previous_end: |
|
annotated_text += text[previous_end:start] |
|
annotated_text += text[start:end] |
|
previous_end = end |
|
annotated_text += text[previous_end:] |
|
|
|
return annotated_text |
|
|
|
|
|
st.title("BERT Token Classification for Anchor Text Prediction") |
|
|
|
|
|
model = load_model('linkbert.pth') |
|
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') |
|
|
|
|
|
user_input = st.text_area("Paste the text you want to analyze:", "Type or paste text here.") |
|
|
|
if st.button("Predict Anchor Texts"): |
|
if user_input: |
|
annotated_text = predict_and_annotate(model, tokenizer, user_input) |
|
st.markdown(annotated_text, unsafe_allow_html=True) |
|
else: |
|
st.write("Please paste some text into the text area.") |
|
|