|
import html
from pprint import pprint

import pandas as pd
import streamlit as st
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForTokenClassification
|
|
|
|
|
@st.cache_resource()
def load_trained_model():
    """Build the NER pipeline for the fine-tuned PLOD-CW BERT checkpoint.

    Loads the tokenizer and the token-classification model from the
    ``LampOfSocrates/bert-cased-plodcw-sourav`` checkpoint, prints the
    model's ``id2label`` mapping, and wraps both in a ``transformers``
    ``"ner"`` pipeline. The function is decorated with
    ``st.cache_resource``.

    Returns:
        The ``"ner"`` pipeline built from the loaded model and tokenizer.
    """
    checkpoint = "LampOfSocrates/bert-cased-plodcw-sourav"

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForTokenClassification.from_pretrained(checkpoint)

    # Print the label inventory the checkpoint was trained with.
    id2label = model.config.id2label
    print(f"Can recognise the following labels {id2label}")

    return pipeline("ner", model=model, tokenizer=tokenizer)
|
|
|
|
|
@st.cache_data()
def load_plod_cw_dataset():
    """Fetch the ``surrey-nlp/PLOD-CW`` dataset via ``datasets.load_dataset``.

    The ``datasets`` import is kept inside the function, as in the original.
    The function is decorated with ``st.cache_data``.

    Returns:
        Whatever ``load_dataset("surrey-nlp/PLOD-CW")`` returns.
    """
    from datasets import load_dataset

    return load_dataset("surrey-nlp/PLOD-CW")
|
|
|
def load_random_examples(dataset_name, num_examples=5):
    """Sample one random row from the PLOD-CW ``test`` split.

    Args:
        dataset_name (str): Accepted for interface compatibility but not
            read; the data always comes from ``load_plod_cw_dataset()``.
        num_examples (int): Accepted but not read; exactly one row is
            sampled regardless of this value.

    Returns:
        pd.DataFrame: A DataFrame built from the sampled ``tokens`` and
        ``ner_tags`` Series.
    """
    test_split = load_plod_cw_dataset()['test']

    # Sample size is fixed at one row.
    sampled = pd.DataFrame(test_split).sample(n=1)

    return pd.DataFrame((sampled.tokens, sampled.ner_tags))
|
|
|
|
|
def render_entities(tokens, entities):
    """Render a two-column "Token" / "Entity" table on the Streamlit page.

    Injects a block of table CSS, writes a title and a short caption, then
    passes ``{"Token": tokens, "Entity": entities}`` to ``st.table``.

    Args:
        tokens: Values for the "Token" column.
        entities: Values for the "Entity" column.
    """
    # Page and table styling, injected as raw HTML.
    table_css = """
    <style>
        body {
            font-family: 'Arial', sans-serif;
            background-color: #f0f0f5;
            color: #333333;
        }
        table {
            width: 100%;
            border-collapse: collapse;
        }
        th, td {
            padding: 12px;
            text-align: left;
            border-bottom: 1px solid #dddddd;
        }
        th {
            background-color: #4CAF50;
            color: white;
            width: 16.66%;
        }
        tr:hover {
            background-color: #f5f5f5;
        }
        td {
            width: 16.66%;
        }
    </style>
    """
    st.markdown(table_css, unsafe_allow_html=True)

    st.title("Model predicted Token vs Entities Table")
    st.write("This table shows the entity corresponding to each token in a cool and chilled theme.")

    st.table({"Token": tokens, "Entity": entities})
|
|
|
def render_random_examples():
    """Show a random PLOD-CW example in a Streamlit table.

    Injects table CSS, writes a title and caption, and renders a button.
    The sample is stored in ``st.session_state['random_examples']``. It is
    (re)loaded when the button is pressed or when no sample is stored yet,
    and the stored value is then passed to ``st.table``.
    """
    # Page and table styling, injected as raw HTML.
    table_css = """
    <style>
        body {
            font-family: 'Arial', sans-serif;
            background-color: #f0f0f5;
            color: #333333;
        }
        table {
            width: 100%;
            border-collapse: collapse;
        }
        th, td {
            padding: 12px;
            text-align: left;
            border-bottom: 1px solid #dddddd;
        }
        th {
            background-color: #4CAF50;
            color: white;
            width: 16.66%;
        }
        tr:hover {
            background-color: #f5f5f5;
        }
        td {
            width: 16.66%;
        }
    </style>
    """
    st.markdown(table_css, unsafe_allow_html=True)

    st.title("Random Examples from PLOD-CW")
    st.write("This table shows 1 random examples from the PLOD-CW dataset in a cool and chilled theme.")

    # The button is evaluated first so the widget is always rendered.
    refresh_requested = st.button('Show another set of random examples')
    if refresh_requested or 'random_examples' not in st.session_state:
        st.session_state['random_examples'] = load_random_examples("surrey-nlp/PLOD-CW")

    st.table(st.session_state['random_examples'])
|
def predict_using_trained(sentence):
    """Run the NER pipeline from ``load_trained_model()`` on ``sentence``.

    Args:
        sentence: The input handed unchanged to the pipeline.

    Returns:
        The pipeline's output for ``sentence``.
    """
    ner = load_trained_model()
    return ner(sentence)
|
|
|
# Background colour used to highlight each predicted label; labels that are
# not listed here fall back to light gray.
LABEL_COLORS = {
    'B-LF': 'lightblue',
    'B-O': 'lightgreen',
    'B-AC': 'lightcoral',
    'I-LF': 'lightyellow',
}


def get_entity_html(text, entities):
    """Return ``text`` as HTML with every predicted entity span highlighted.

    Each entity becomes its own coloured ``<div>``; the text between
    entities is emitted with spaces turned into ``<br>``. All text taken
    from ``text`` is HTML-escaped first, because the result is rendered
    with ``unsafe_allow_html=True`` and ``text`` is typed by the user.

    Args:
        text (str): The raw input sentence.
        entities: Iterable of dicts carrying ``'start'``, ``'end'`` (character
            offsets into ``text``) and ``'entity'`` (the predicted label),
            in ascending offset order.

    Returns:
        str: An HTML fragment wrapped in a single outer ``<div>``.
    """
    parts = ["<div>"]
    last_idx = 0
    for entity in entities:
        start, end = entity['start'], entity['end']
        color = LABEL_COLORS.get(entity['entity'], 'lightgray')

        # Escape before inserting <br>, otherwise the tags would be escaped too.
        parts.append(html.escape(text[last_idx:start]).replace(" ", "<br>"))
        parts.append(
            f'<div style="background-color: {color}; padding: 5px; '
            f'border-radius: 3px; margin: 5px 0;">{html.escape(text[start:end])}</div>'
        )
        last_idx = end

    # Whatever follows the last entity.
    parts.append(html.escape(text[last_idx:]).replace(" ", "<br>"))
    parts.append("</div>")
    return "".join(parts)


def prep_page():
    """Render the interactive NER page.

    Loads the pipeline, shows a text area, and once text has been entered:
    prints the raw predictions, renders the highlighted sentence, and
    renders the token/label table. The random-example section is rendered
    on every run, whether or not text was entered.
    """
    model = load_trained_model()

    st.title("Named Entity Recognition with BERT on PLOD-CW")
    st.write("Enter a sentence to see the named entities recognized by the model.")

    text = st.text_area("Enter your sentence here:")

    if text:
        st.write("Entities recognized:")
        entities = model(text)

        pprint(entities)

        st.markdown(get_entity_html(text, entities), unsafe_allow_html=True)

        # The table expects one token and one label per row, not the whole
        # sentence plus the raw prediction dicts.
        render_entities(
            [entity['word'] for entity in entities],
            [entity['entity'] for entity in entities],
        )

    render_random_examples()
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Two modes, selected by the URL query string:
    #   ?api&sentence=...  -> run the model on `sentence` and write the result
    #   (no `api` key)     -> render the interactive page
    query_params = st.query_params
    if 'api' in query_params:
        sentence = query_params.get('sentence')
        if sentence is None:
            # Without this guard the pipeline would be called with None.
            st.error("Missing required query parameter: 'sentence'")
        else:
            entities = predict_using_trained(sentence)
            response = {"sentence": sentence, "entities": entities}
            pprint(response)

            st.write(response)
    else:
        prep_page()
|
|