dejanseo committed on
Commit
43b042c
1 Parent(s): a9a3546

Upload 2 files

Browse files
Files changed (2) hide show
  1. inference.py +59 -0
  2. linkbert.pth +3 -0
inference.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import BertForTokenClassification, BertTokenizerFast # Import BertTokenizerFast
4
+
5
def load_model(model_name='linkbert.pth'):
    """Build the token classifier and restore its fine-tuned weights.

    A two-label ``BertForTokenClassification`` is created from the
    ``bert-base-uncased`` checkpoint, then the state dict stored at
    ``model_name`` is loaded into it on the CPU.

    Args:
        model_name: Path to the saved state dict file.

    Returns:
        The model, switched to evaluation mode.
    """
    classifier = BertForTokenClassification.from_pretrained(
        'bert-base-uncased',
        num_labels=2,
    )

    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed PyTorch version's defaults -- only load trusted checkpoints.
    cpu = torch.device('cpu')
    state_dict = torch.load(model_name, map_location=cpu)
    classifier.load_state_dict(state_dict)

    # Inference mode: disables dropout and similar training-only behaviour.
    classifier.eval()
    return classifier
11
+
12
def predict_and_annotate(model, tokenizer, text):
    """Return ``text`` as HTML with predicted anchor tokens wrapped in ``<u>``.

    The text is tokenized with character offsets, the model is run once
    without gradient tracking, and every token whose arg-max label is ``1``
    is wrapped in ``<u>...</u>``. Characters between tokens (for example
    whitespace) and any text after the last token are copied through
    unannotated.

    All text copied from ``text`` is HTML-escaped, because the result is
    meant to be rendered as raw HTML; without escaping, ``<``, ``>`` and
    ``&`` typed by the user would be interpreted as markup.

    Args:
        model: Callable invoked as ``model(input_ids=..., attention_mask=...)``
            that returns an object with a ``logits`` tensor.
        tokenizer: Callable tokenizer that accepts ``return_tensors="pt"`` and
            ``return_offsets_mapping=True`` and returns ``input_ids``,
            ``attention_mask`` and ``offset_mapping`` entries.
        text: The raw text to annotate.

    Returns:
        The escaped text with predicted anchor tokens underlined.
    """
    import html  # local import: keeps this block self-contained

    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        return_offsets_mapping=True,
    )

    with torch.no_grad():
        outputs = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )

    # A single text is passed in, so index the only batch row explicitly
    # instead of relying on squeeze() to guess which dimensions to drop.
    labels = torch.argmax(outputs.logits, dim=-1)[0].tolist()
    offsets = encoded["offset_mapping"][0].tolist()

    pieces = []
    cursor = 0  # index in ``text`` up to which output has been emitted
    for (start, end), label in zip(offsets, labels):
        if start == end:  # zero-width offsets mark special tokens
            continue
        if start > cursor:
            # Characters between tokens that the tokenizer skipped.
            pieces.append(html.escape(text[cursor:start], quote=False))
        token_text = html.escape(text[start:end], quote=False)
        pieces.append(f"<u>{token_text}</u>" if label == 1 else token_text)
        cursor = end

    # Whatever follows the last token (including text cut off by truncation).
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)
43
+
44
# Streamlit app setup
st.title("BERT Token Classification for Anchor Text Prediction")


@st.cache_resource
def _load_resources():
    """Load the model and tokenizer once and reuse them across reruns.

    Streamlit re-executes this script on every widget interaction; without
    caching, the checkpoint and tokenizer would be reloaded each time.
    NOTE(review): the cached objects are shared between sessions -- this
    assumes inference does not mutate the model; confirm before adding
    any training or fine-tuning code paths.
    """
    cached_model = load_model('linkbert.pth')
    cached_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')  # fast tokenizer: needed for offset mappings
    return cached_model, cached_tokenizer


# Load the model and tokenizer
model, tokenizer = _load_resources()

# User input text area
user_input = st.text_area("Paste the text you want to analyze:", "Type or paste text here.")

if st.button("Predict Anchor Texts"):
    # Treat whitespace-only input the same as empty input.
    if user_input.strip():
        annotated_text = predict_and_annotate(model, tokenizer, user_input)
        # The annotated text contains <u> tags, so HTML rendering must be enabled.
        st.markdown(annotated_text, unsafe_allow_html=True)
    else:
        st.write("Please paste some text into the text area.")
linkbert.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81dc286402b449bf1e0348dbd7f8bb0b64a284f452bd4e0b2bb41ddbac492a24
3
+ size 435654416