File size: 2,576 Bytes
b3e4e96
 
 
 
 
26dff99
 
 
 
efb23c9
 
26dff99
 
 
 
 
 
 
 
 
efb23c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
920a8fe
efb23c9
 
 
26dff99
 
b3e4e96
 
 
 
 
766dac7
b3e4e96
efb23c9
b3e4e96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import pandas as pd
from spacy import displacy
from spacy.tokens import Doc
from spacy.vocab import Vocab
from spacy_streamlit.util import get_html
import streamlit as st
import torch
from transformers import BertTokenizerFast

from model import BertForTokenAndSequenceJointClassification


@st.cache(allow_output_mutation=True)
def load_model():
    """Load the tokenizer and the joint token/sequence classification model.

    The function is wrapped in ``st.cache`` so the loaded objects are reused
    across Streamlit reruns instead of being reloaded on every interaction.

    Returns:
        A ``(tokenizer, model)`` tuple: the ``bert-base-cased`` fast tokenizer
        and the ``QCRI/PropagandaTechniquesAnalysis-en-BERT`` checkpoint
        pinned to revision ``v0.1.0``.
    """
    checkpoint = "QCRI/PropagandaTechniquesAnalysis-en-BERT"
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
    model = BertForTokenAndSequenceJointClassification.from_pretrained(
        checkpoint, revision="v0.1.0"
    )
    return tokenizer, model

# Everything below only runs the model forward, so autograd tracking is
# switched off for the whole block.
with torch.inference_mode(True):
    tokenizer, model = load_model()

    st.write("[Propaganda Techniques Analysis BERT](https://huggingface.co/QCRI/PropagandaTechniquesAnalysis-en-BERT) Tagger")

    # Named `text` (not `input`) so the builtin `input()` is not shadowed.
    text = st.text_area('Input', """\
    In some instances, it can be highly dangerous to use a medicine for the prevention or treatment of COVID-19 that has not been approved by or has not received emergency use authorization from the FDA.
    """)

    # truncation=True clips over-long text to the tokenizer's configured
    # maximum length instead of letting the forward pass fail on a sequence
    # longer than the model supports.
    inputs = tokenizer.encode_plus(text, return_tensors="pt", truncation=True)
    outputs = model(**inputs)

    # Highest-scoring sequence-level class for the single input.
    # NOTE(review): `sequence_class` is not read anywhere else in this script.
    sequence_class_index = torch.argmax(outputs.sequence_logits, dim=-1)
    sequence_class = model.sequence_tags[sequence_class_index[0]]

    # Highest-scoring tag per token. The [1:-1] slices drop the first and last
    # positions (the special tokens the tokenizer wraps around the text) so
    # `tokens` and `tags` stay aligned one-to-one.
    token_class_index = torch.argmax(outputs.token_logits, dim=-1)
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0][1:-1])
    tags = [model.token_tags[i] for i in token_class_index[0].tolist()[1:-1]]

# One metric tile per sequence-level class, showing its logit to two decimals.
sequence_scores = outputs.sequence_logits.flatten()
columns = st.columns(len(sequence_scores))
for column, tag_name, score in zip(columns, model.sequence_tags, sequence_scores):
    column.metric(tag_name, f"{score.item():.2f}")


# spaCy needs one "is this word followed by whitespace?" flag per word.
# A token is followed by a space unless the NEXT token is a "##" continuation
# piece, and the final token never is. The trailing False is appended only
# when there are tokens, so `spaces` always has the same length as `tokens`;
# an empty input would otherwise yield [False] for zero words and make Doc()
# raise on the length mismatch.
spaces = [not following.startswith('##') for following in tokens[1:]]
if tokens:
    spaces.append(False)

# Each non-"O" tag is prefixed with "B-", so every tagged token starts its own
# entity span (adjacent tokens with the same tag are not merged).
doc = Doc(Vocab(strings=set(tokens)),
          words=tokens,
          spaces=spaces,
          ents=[tag if tag == "O" else f"B-{tag}" for tag in tags])

# Selectable tags: everything after the first two entries of token_tags
# (presumably the padding and "O" tags -- confirm against the model class).
labels = model.token_tags[2:]

# All tags are selected by default; the user can narrow the highlighted set.
label_select = st.multiselect(
    "Tags", options=labels, default=labels, key="tags_ner_label_select"
)

# Render only the selected tags as highlighted entities.
render_options = {"ents": label_select, "colors": {}}
html = displacy.render(doc, style="ent", options=render_options)

# The CSS is injected together with the markup in a single st.write call.
style = "<style>mark.entity { display: inline-block }</style>"
st.write(style + get_html(html), unsafe_allow_html=True)

# Tabulate the selected entities, one row per entity, every value as a string.
attrs = ["text", "label_", "start", "end", "start_char", "end_char"]
data = []
for ent in doc.ents:
    if ent.label_ not in label_select:
        continue
    data.append([str(getattr(ent, attr)) for attr in attrs])

# Show the table only when at least one entity matched the selection.
if data:
    st.dataframe(pd.DataFrame(data, columns=attrs))