"""Streamlit app for NER with a BERT model fine-tuned on the PLOD-CW dataset.

The interactive page lets a user type a sentence, runs the Hugging Face
``"ner"`` pipeline on it, highlights each predicted label inline, lists the
predictions in a table, and shows a random example from the PLOD-CW test
split. When the page is opened with an ``api`` query parameter it instead
writes the raw predictions for the ``sentence`` query parameter.
"""
from html import escape
from pprint import pprint

import pandas as pd
import streamlit as st
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForTokenClassification

# Hugging Face Hub ids used by this app.
MODEL_ID = "LampOfSocrates/bert-cased-plodcw-sourav"
PLOD_CW_DATASET = "surrey-nlp/PLOD-CW"

# Background colour per predicted label for the inline highlighting; labels
# not listed here fall back to DEFAULT_LABEL_COLOR.
LABEL_COLORS = {
    'B-LF': 'lightblue',
    'B-O': 'lightgreen',
    'B-AC': 'lightcoral',
    'I-LF': 'lightyellow',
}
DEFAULT_LABEL_COLOR = 'lightgray'


@st.cache_resource()
def load_trained_model():
    """Build the NER pipeline for the fine-tuned model.

    Wrapped in ``st.cache_resource`` so the model and tokenizer are loaded
    once and reused across reruns instead of on every interaction.

    Returns:
        A transformers ``"ner"`` pipeline wrapping the model and tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)

    # Log the label mapping so the server output shows what the model can emit.
    id2label = model.config.id2label
    print(f"Can recognise the following labels {id2label}")

    return pipeline("ner", model=model, tokenizer=tokenizer)


@st.cache_data()
def load_plod_cw_dataset(dataset_name=PLOD_CW_DATASET):
    """Load a dataset from the Hugging Face Hub, cached by ``st.cache_data``.

    Args:
        dataset_name (str): Hub id of the dataset; defaults to PLOD-CW.

    Returns:
        The object returned by ``datasets.load_dataset``; callers index it by
        split name (e.g. ``"test"``).
    """
    from datasets import load_dataset
    return load_dataset(dataset_name)


def load_random_examples(dataset_name, num_examples=5):
    """
    Load random examples from the test split of a Hugging Face dataset.

    Args:
        dataset_name (str): The name of the dataset to load.
        num_examples (int): The number of random examples to load; clamped to
            the size of the test split so sampling never raises.

    Returns:
        pd.DataFrame: Two rows, ``tokens`` and ``ner_tags``, with one column
        per sampled example holding that example's value for the field.
    """
    dataset = load_plod_cw_dataset(dataset_name)
    test_df = pd.DataFrame(dataset['test'])
    sample = test_df.sample(n=min(num_examples, len(test_df)))
    return pd.DataFrame((sample.tokens, sample.ner_tags))


def render_entities(tokens, entities):
    """
    Renders a page with a 2-column table showing the entity corresponding to each token.

    Args:
        tokens (list[str]): Token strings, in order.
        entities (list[str]): Predicted label for each token, aligned with ``tokens``.
    """
    # NOTE(review): the "chilled and cool" custom CSS that used to be injected
    # here was empty in this copy of the file; re-add a <style> block via
    # st.markdown(..., unsafe_allow_html=True) if the themed look is wanted.
    st.title("Model predicted Token vs Entities Table")
    st.write("This table shows the entity corresponding to each token in a cool and chilled theme.")
    st.table({"Token": tokens, "Entity": entities})


def render_random_examples():
    """
    Render a random example from the PLOD-CW test split in a Streamlit table.

    The sample is stored in ``st.session_state`` so it survives reruns; the
    button replaces it with a fresh one.
    """
    # NOTE(review): the custom CSS block that used to be injected here was
    # empty in this copy of the file; restore it if the themed look is wanted.
    st.title("Random Examples from PLOD-CW")
    st.write("This table shows 1 random examples from the PLOD-CW dataset in a cool and chilled theme.")

    # st.button is the left operand so the button is always rendered; a new
    # sample is drawn on click or on the session's first render.
    if st.button('Show another set of random examples') or 'random_examples' not in st.session_state:
        st.session_state['random_examples'] = load_random_examples(PLOD_CW_DATASET, num_examples=1)

    st.table(st.session_state['random_examples'])


def predict_using_trained(sentence):
    """Run the cached NER pipeline on ``sentence`` and return its raw predictions."""
    model = load_trained_model()
    return model(sentence)


def _entity_html(text, entities, label_colors=LABEL_COLORS):
    """Return HTML for ``text`` with every predicted entity span highlighted.

    Args:
        text (str): The sentence the predictions were made on.
        entities: Pipeline predictions, each a mapping with ``start``/``end``
            character offsets into ``text`` and an ``entity`` label.
        label_colors (dict): Label -> CSS colour used for the highlight.

    Returns:
        str: A single-line HTML fragment. All user-derived text is escaped
        because the caller renders it with ``unsafe_allow_html=True``.
    """
    def plain(segment):
        # Escape user text; turn newlines into <br> so the user's line breaks
        # stay visible and a blank line cannot end the markdown HTML block.
        return escape(segment).replace("\n", "<br>")

    parts = ['<div style="line-height: 2;">']
    last_idx = 0
    for entity in entities:
        start, end, label = entity['start'], entity['end'], entity['entity']
        color = label_colors.get(label, DEFAULT_LABEL_COLOR)
        # Text between the previous entity and this one, unstyled.
        parts.append(plain(text[last_idx:start]))
        # The entity itself, coloured by label; hovering shows the label.
        parts.append(
            f'<span style="background-color: {color}; padding: 0 0.2em; '
            f'border-radius: 0.25em;" title="{escape(label)}">'
            f'{escape(text[start:end])}</span>'
        )
        last_idx = end
    # Any text after the last entity.
    parts.append(plain(text[last_idx:]))
    parts.append('</div>')
    return ''.join(parts)


def prep_page():
    """Render the interactive NER page.

    Shows an input box; for non-empty input, displays the sentence with
    predicted labels highlighted plus a token/label table. Always finishes
    with a random PLOD-CW example.
    """
    model = load_trained_model()

    st.title("Named Entity Recognition with BERT on PLOD-CW")
    st.write("Enter a sentence to see the named entities recognized by the model.")

    text = st.text_area("Enter your sentence here:")

    if text:
        st.write("Entities recognized:")
        entities = model(text)
        pprint(entities)  # raw predictions, printed to the server log

        st.markdown(_entity_html(text, entities), unsafe_allow_html=True)
        # One table row per prediction: the covered text and its label.
        render_entities(
            [text[e['start']:e['end']] for e in entities],
            [e['entity'] for e in entities],
        )

    render_random_examples()


if __name__ == '__main__':
    query_params = st.query_params
    if 'api' in query_params:
        # "API" mode: the presence of ?api switches to returning raw
        # predictions for the ?sentence= parameter.
        sentence = query_params.get('sentence')
        if not sentence:
            st.error("Missing required query parameter: sentence")
        else:
            entities = predict_using_trained(sentence)
            response = {"sentence": sentence, "entities": entities}
            pprint(response)
            st.write(response)
    else:
        prep_page()