import random
from typing import AnyStr

import streamlit as st
from bs4 import BeautifulSoup
import numpy as np
import base64
from spacy_streamlit.util import get_svg
from custom_renderer import render_sentence_custom
from flair.data import Sentence
from flair.models import SequenceTagger
import spacy
from spacy import displacy
from spacy_streamlit import visualize_parser
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline
import os
from transformers_interpret import SequenceClassificationExplainer

# Map model names to URLs
model_names_to_URLs = {
    'ml6team/distilbert-base-dutch-cased-toxic-comments':
        'https://huggingface.co/ml6team/distilbert-base-dutch-cased-toxic-comments',
    'ml6team/robbert-dutch-base-toxic-comments':
        'https://huggingface.co/ml6team/robbert-dutch-base-toxic-comments',
}

about_page_markdown = f"""# 🤬 Dutch Toxic Comment Detection Space

Made by [ML6](https://ml6.eu/).

Token attribution is performed using [transformers-interpret](https://github.com/cdpierse/transformers-interpret).
"""

regular_emojis = [
    '😐', '🙂', '👶', '😇',
]
undecided_emojis = [
    '🤨', '🧐', '🥸', '🥴', '🤷',
]
potty_mouth_emojis = [
    '🤐', '👿', '😡', '🤬', '☠️', '☣️', '☢️',
]

# Page setup
st.set_page_config(
    page_title="Toxic Comment Detection Space",
    page_icon="🤬",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items={
        'Get help': None,
        'Report a bug': None,
        'About': about_page_markdown,
    }
)


# Model setup
@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
def load_pipeline(model_name):
    with st.spinner('Loading model (this might take a while)...'):
        toxicity_pipeline = pipeline(
            'text-classification',
            model=model_name,
            tokenizer=model_name)
        cls_explainer = SequenceClassificationExplainer(
            toxicity_pipeline.model,
            toxicity_pipeline.tokenizer)
    return toxicity_pipeline, cls_explainer


# Auxiliary functions
def format_explainer_html(html_string):
    """Extract tokens with attribution-based background color."""
    inside_token_prefix = '##'
    soup = BeautifulSoup(html_string, 'html.parser')
    p = soup.new_tag('p',
                     attrs={'style': 'color: black; background-color: white;'})
    # Select token elements and remove model specific tokens
    current_word = None
    for token in soup.find_all('td')[-1].find_all('mark')[1:-1]:
        text = token.font.text.strip()
        if text.startswith(inside_token_prefix):
            text = text[len(inside_token_prefix):]
        else:
            # Create a new span for each word (sequence of sub-tokens)
            if current_word is not None:
                p.append(current_word)
                p.append(' ')
            current_word = soup.new_tag('span')
        token.string = text
        token.attrs['style'] = f"{token.attrs['style']}; padding: 0.2em 0em;"
        current_word.append(token)

    # Add last word
    p.append(current_word)

    # Add left and right-padding to each word
    for span in p.find_all('span'):
        span.find_all('mark')[0].attrs['style'] = (
            f"{span.find_all('mark')[0].attrs['style']}; padding-left: 0.2em;")
        span.find_all('mark')[-1].attrs['style'] = (
            f"{span.find_all('mark')[-1].attrs['style']}; padding-right: 0.2em;")

    return p


def list_all_article_names() -> list:
    filenames = []
    for file in os.listdir('./sample-articles/'):
        if file.endswith('.txt'):
            filenames.append(file.replace('.txt', ''))
    return filenames


def fetch_article_contents(filename: str) -> AnyStr:
    with open(f'./sample-articles/{filename.lower()}.txt', 'r') as f:
        data = f.read()
    return data


def fetch_summary_contents(filename: str) -> AnyStr:
    with open(f'./sample-summaries/{filename.lower()}.txt', 'r') as f:
        data = f.read()
    return data
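

# The entity and dependency helpers further down reload their spaCy and Flair models on every
# button press. A minimal sketch of caching those loads with the same st.cache pattern used by
# load_pipeline above; both helpers are illustrative and not wired into the functions below.
@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
def load_spacy_model(model_name: str = 'en_core_web_lg'):
    with st.spinner('Loading spaCy model...'):
        return spacy.load(model_name)


@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
def load_flair_tagger(model_name: str = "flair/ner-english-ontonotes-fast"):
    with st.spinner('Loading Flair NER tagger...'):
        return SequenceTagger.load(model_name)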


def classify_comment(comment, selected_model):
    """Classify the given comment and augment with additional information."""
    toxicity_pipeline, cls_explainer = load_pipeline(selected_model)
    result = toxicity_pipeline(comment)[0]
    result['model_name'] = selected_model

    # Add explanation
    result['word_attribution'] = cls_explainer(comment, class_name="non-toxic")
    result['visualization_html'] = cls_explainer.visualize()._repr_html_()
    result['tokens_with_background'] = format_explainer_html(
        result['visualization_html'])

    # Choose emoji reaction
    label, score = result['label'], result['score']
    if label == 'toxic' and score > 0.1:
        emoji = random.choice(potty_mouth_emojis)
    elif label in ['non_toxic', 'non-toxic'] and score > 0.1:
        emoji = random.choice(regular_emojis)
    else:
        emoji = random.choice(undecided_emojis)
    result.update({'text': comment, 'emoji': emoji})

    # Add result to session
    st.session_state.results.append(result)


# Start session
if 'results' not in st.session_state:
    st.session_state.results = []

# Page
# st.title('🤬 Dutch Toxic Comment Detection')
# st.markdown("""This demo showcases two Dutch toxic comment detection models.""")
#
# # Introduction
# st.markdown(f"""Both models were trained using a sequence classification task on a translated [Jigsaw Toxicity dataset](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge) which contains toxic online comments.
# The first model is a fine-tuned multilingual [DistilBERT](https://huggingface.co/distilbert-base-multilingual-cased) model whereas the second is a fine-tuned Dutch RoBERTa-based model called [RobBERT](https://huggingface.co/pdelobelle/robbert-v2-dutch-base).""")
#
# st.markdown(f"""For a more comprehensive overview of the models check out their model card on 🤗 Model Hub: [distilbert-base-dutch-toxic-comments]({model_names_to_URLs['ml6team/distilbert-base-dutch-cased-toxic-comments']}) and [RobBERT-dutch-base-toxic-comments]({model_names_to_URLs['ml6team/robbert-dutch-base-toxic-comments']}).
# """)
#
# st.markdown("""Enter a comment that you want to classify below. The model will determine the probability that it is toxic and highlights how much each token contributes to its decision:
# red tokens indicate toxicity whereas green tokens indicate the opposite.
#
# Try it yourself! 👇""",
#             unsafe_allow_html=True)

# Demo
# with st.form("dutch-toxic-comment-detection-input", clear_on_submit=True):
#     selected_model = st.selectbox('Select a model:',
#                                   model_names_to_URLs.keys())
#     text = st.text_area(
#         label='Enter the comment you want to classify below (in Dutch):')
#     _, rightmost_col = st.columns([6, 1])
#     submitted = rightmost_col.form_submit_button("Classify",
#                                                  help="Classify comment")

# TODO: should probably set a minimum length of article or something (a sketch of such a guard follows below)
selected_article = st.selectbox('Select an article or provide your own:',
                                list_all_article_names())
st.session_state.article_text = fetch_article_contents(selected_article)
article_text = st.text_area(
    label='Full article text',
    value=st.session_state.article_text,
    height=250
)

# _, rightmost_col = st.columns([5, 1])
# get_summary = rightmost_col.button("Generate summary",
#                                    help="Generate summary for the given article text")
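
# The TODO above asks for a minimum article length. A minimal sketch of such a guard;
# the 10-word floor is an assumption and the helper is not yet wired into the forms below.
MIN_ARTICLE_WORDS = 10  # assumed threshold


def article_long_enough(text: str, min_words: int = MIN_ARTICLE_WORDS) -> bool:
    """Illustrative only: reject article texts that are too short to summarize."""
    return len(text.split()) >= min_words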


def display_summary(article_name: str):
    st.subheader("Generated summary")
    # st.markdown("######")
    summary_content = fetch_summary_contents(article_name)
    soup = BeautifulSoup(summary_content, features="html.parser")
    HTML_WRAPPER = """
{}
"""
    st.session_state.summary_output = HTML_WRAPPER.format(soup)
    st.write(st.session_state.summary_output, unsafe_allow_html=True)
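

# display_summary above shows a pre-written summary from ./sample-summaries. A minimal sketch
# of generating one on the fly with the transformers pipeline that is already imported; the
# checkpoint name is an assumption (any seq2seq summarization model would do) and this helper
# is not wired into the "Generate summary" form below.
@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
def load_summarization_pipeline(model_name: str = "sshleifer/distilbart-cnn-12-6"):
    with st.spinner('Loading summarization model (this might take a while)...'):
        return pipeline("summarization", model=model_name, tokenizer=model_name)


def generate_summary(article_text: str) -> str:
    """Illustrative only: summarize the given article text instead of reading a sample file."""
    summarizer = load_summarization_pipeline()
    # Truncation keeps long articles within the model's maximum input length.
    output = summarizer(article_text, truncation=True)
    return output[0]["summary_text"]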


# TODO: this functionality can be cached (e.g. by storing html file output) if wanted (or just store list of entities idk)
def get_and_compare_entities_spacy(article_name: str):
    nlp = spacy.load('en_core_web_lg')

    article_content = fetch_article_contents(article_name)
    doc = nlp(article_content)
    # entities_article = doc.ents
    entities_article = []
    for entity in doc.ents:
        entities_article.append(str(entity))

    summary_content = fetch_summary_contents(article_name)
    doc = nlp(summary_content)
    # entities_summary = doc.ents
    entities_summary = []
    for entity in doc.ents:
        entities_summary.append(str(entity))

    matched_entities = []
    unmatched_entities = []
    for entity in entities_summary:
        # TODO: currently substring matching but probably should do embedding method or idk?
        # (see the similarity-based sketch at the bottom of this file)
        if any(entity.lower() in substring_entity.lower()
               for substring_entity in entities_article):
            matched_entities.append(entity)
        else:
            unmatched_entities.append(entity)
    # print(entities_article)
    # print(entities_summary)
    return matched_entities, unmatched_entities


def get_and_compare_entities_flair(article_name: str):
    nlp = spacy.load('en_core_web_sm')
    tagger = SequenceTagger.load("flair/ner-english-ontonotes-fast")

    article_content = fetch_article_contents(article_name)
    doc = nlp(article_content)
    entities_article = []
    sentences = list(doc.sents)
    for sentence in sentences:
        sentence_entities = Sentence(str(sentence))
        tagger.predict(sentence_entities)
        for entity in sentence_entities.get_spans('ner'):
            entities_article.append(entity.text)

    summary_content = fetch_summary_contents(article_name)
    doc = nlp(summary_content)
    entities_summary = []
    sentences = list(doc.sents)
    for sentence in sentences:
        sentence_entities = Sentence(str(sentence))
        tagger.predict(sentence_entities)
        for entity in sentence_entities.get_spans('ner'):
            entities_summary.append(entity.text)

    matched_entities = []
    unmatched_entities = []
    for entity in entities_summary:
        # TODO: currently substring matching but probably should do embedding method or idk?
        if any(entity.lower() in substring_entity.lower()
               for substring_entity in entities_article):
            matched_entities.append(entity)
        else:
            unmatched_entities.append(entity)
    # print(entities_article)
    # print(entities_summary)
    return matched_entities, unmatched_entities


def highlight_entities(article_name: str):
    st.subheader("Match entities with article")
    # st.markdown("####")
    summary_content = fetch_summary_contents(article_name)

    # Illustrative inline-styled spans (assumed colors) marking matched (green) and
    # unmatched (red) entities in the rendered summary.
    markdown_start_red = '<span style="background-color: #f08080">'
    markdown_start_green = '<span style="background-color: #90ee90">'
    markdown_end = '</span>'

    matched_entities, unmatched_entities = get_and_compare_entities_spacy(article_name)
    for entity in matched_entities:
        summary_content = summary_content.replace(
            entity, markdown_start_green + entity + markdown_end)
    for entity in unmatched_entities:
        summary_content = summary_content.replace(
            entity, markdown_start_red + entity + markdown_end)

    soup = BeautifulSoup(summary_content, features="html.parser")
    HTML_WRAPPER = """
{}
"""
    st.write(HTML_WRAPPER.format(soup), unsafe_allow_html=True)


def render_dependency_parsing(text: str):
    nlp = spacy.load('en_core_web_sm')
    # doc = nlp(text)
    # st.write(displacy.render(doc, style='dep'))
    # sentence_spans = list(doc.sents)
    # dep_svg = displacy.serve(sentence_spans, style="dep")
    # dep_svg = displacy.render(doc, style="dep", jupyter=False,
    #                           options={"compact": False})
    # st.image(dep_svg, width=50, use_column_width=True)
    # visualize_parser(doc)
    # docs = [doc]
    # split_sents = True
    # docs = [span.as_doc() for span in doc.sents] if split_sents else [doc]
    # for sent in docs:
    html = render_sentence_custom(text)
    # Double newlines seem to mess with the rendering
    html = html.replace("\n\n", "\n")
    st.write(get_svg(html), unsafe_allow_html=True)
    # st.image(html, width=50, use_column_width=True)


def check_dependency(text):
    tagger = SequenceTagger.load("flair/ner-english-ontonotes-fast")
    nlp = spacy.load('en_core_web_lg')
    doc = nlp(text)
    tok_l = doc.to_json()['tokens']
    # all_deps = []
    all_deps = ""
    sentences = list(doc.sents)
    for sentence in sentences:
        all_entities = []
        # ENTITIES WITH SPACY:
        for entity in sentence.ents:
            all_entities.append(str(entity))
        # ENTITIES WITH FLAIR:
        sentence_entities = Sentence(str(sentence))
        tagger.predict(sentence_entities)
        for entity in sentence_entities.get_spans('ner'):
            all_entities.append(entity.text)
        # ENTITIES WITH XLM ROBERTA
        # entities_xlm = [entity["word"] for entity in ner_model(str(sentence))]
        # for entity in entities_xlm:
        #     all_entities.append(str(entity))
        start_id = sentence.start
        end_id = sentence.end
        for t in tok_l:
            if t["id"] < start_id or t["id"] > end_id:
                continue
            head = tok_l[t['head']]
            if t['dep'] == 'amod':
                object_here = text[t['start']:t['end']]
                object_target = text[head['start']:head['end']]
                # ONE NEEDS TO BE ENTITY
                if object_here in all_entities:
                    # all_deps.append(f"'{text[t['start']:t['end']]}' is {t['dep']} of '{text[head['start']:head['end']]}'")
                    # Collect the sentence so its dependency parse can be rendered.
                    all_deps = all_deps + str(sentence) + " "
                elif object_target in all_entities:
                    # all_deps.append(f"'{text[t['start']:t['end']]}' is {t['dep']} of '{text[head['start']:head['end']]}'")
                    all_deps = all_deps + str(sentence) + " "
                else:
                    continue
    return all_deps


with st.form("article-input"):
    left_column, _ = st.columns([1, 1])
    get_summary = left_column.form_submit_button("Generate summary",
                                                 help="Generate summary for the given article text")

    # Listener
    if get_summary:
        if article_text:
            with st.spinner('Generating summary...'):
                # classify_comment(article_text, selected_model)
                display_summary(selected_article)
        else:
            st.error('**Error**: No article text to summarize. Please provide an article.')

# Entity part
with st.form("Entity-part"):
    left_column, _ = st.columns([1, 1])
    draw_entities = left_column.form_submit_button("Draw Entities",
                                                   help="Draw Entities")
    if draw_entities:
        with st.spinner("Drawing entities..."):
            highlight_entities(selected_article)

with st.form("Dependency-usage"):
    left_column, _ = st.columns([1, 1])
    parsing = left_column.form_submit_button("Dependency parsing",
                                             help="Dependency parsing")
    if parsing:
        with st.spinner("Doing dependency parsing..."):
            render_dependency_parsing(check_dependency(fetch_summary_contents(selected_article)))

# Results
# if 'results' in st.session_state and st.session_state.results:
#     first = True
#     for result in st.session_state.results[::-1]:
#         if not first:
#             st.markdown("---")
#         st.markdown(f"Text:\n> {result['text']}")
#         col_1, col_2, col_3 = st.columns([1, 2, 2])
#         col_1.metric(label='', value=f"{result['emoji']}")
#         col_2.metric(label='Label', value=f"{result['label']}")
#         col_3.metric(label='Score', value=f"{result['score']:.3f}")
#         st.markdown(f"Token Attribution:\n{result['tokens_with_background']}",
#                     unsafe_allow_html=True)
#         st.caption(f"Model: {result['model_name']}")
#         first = False
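

# The TODOs in get_and_compare_entities_spacy / get_and_compare_entities_flair suggest replacing
# substring matching with an embedding-based method. A minimal sketch using spaCy's
# en_core_web_lg vectors; the 0.75 threshold is an assumption and the helper is not wired into
# the matching functions above.
def entity_matches_by_similarity(entity: str, article_entities: list, nlp,
                                 threshold: float = 0.75) -> bool:
    """Illustrative only: check whether an entity is semantically close to any article entity."""
    entity_doc = nlp(entity)
    if not entity_doc.vector_norm:
        # No vector available (e.g. a rare token): fall back to the current substring matching.
        return any(entity.lower() in candidate.lower() for candidate in article_entities)
    for candidate in article_entities:
        candidate_doc = nlp(candidate)
        if candidate_doc.vector_norm and entity_doc.similarity(candidate_doc) >= threshold:
            return True
    return False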