"""Streamlit demo of Japanese NLP with Transformers.

Three pages are selectable from the sidebar:
  * Sentiment Analysis of Japanese Amazon reviews (fine-tuned Japanese BERT),
  * Text Summarization of Japanese news articles (multilingual T5),
  * Japanese to English Translation of business dialogs (fine-tuned Marian model).
"""
import json
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import pytorch_lightning as pl
import streamlit as st
import torch
from st_aggrid import AgGrid
from st_aggrid.grid_options_builder import GridOptionsBuilder
from st_aggrid.shared import GridUpdateMode, JsCode
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelWithLMHead,
    AutoTokenizer,
    BertForSequenceClassification,
    BertTokenizer,
    T5Tokenizer,
)

# Options of the review-source radio button on the sentiment page.
SELECT_ONE_REVIEW = "Choose a review from the dataframe below"
WRITE_REVIEW = "Manually write review"
# Column index of each sentiment in the sentiment classifier's logits.
SENTIMENT_LABELS = {'positive': 1, 'negative': 0}


# ---------------------------------------------------------------------------
# Model / tokenizer loaders, cached across reruns by streamlit
# ---------------------------------------------------------------------------
class TranslationModel(pl.LightningModule):
    """Lightning wrapper around the Helsinki-NLP ja->en model.

    Only used to restore the fine-tuned weights from a Lightning checkpoint;
    inference goes through ``self.model.generate``.
    """

    def __init__(self):
        super().__init__()
        self.model = AutoModelWithLMHead.from_pretrained("Helsinki-NLP/opus-mt-ja-en", return_dict=True)


@st.experimental_singleton
def loadFineTunedJaEn_NMT_Model():
    """Return the fine-tuned Japanese->English ``TranslationModel``.

    The Lightning checkpoint is downloaded from Google Drive into ``model/``
    the first time it is needed.
    """
    save_dest = Path('model')
    save_dest.mkdir(exist_ok=True)
    f_checkpoint = save_dest / "best-checkpoint.ckpt"
    if not f_checkpoint.exists():
        with st.spinner("Downloading model. This may take a while! \n Don't refresh or close this page!"):
            from GD_download import download_file_from_google_drive
            # Download to a temporary name first: an interrupted download must
            # not leave a truncated file that later runs mistake for a finished one.
            partial = f_checkpoint.with_name(f_checkpoint.name + ".part")
            download_file_from_google_drive('1CZQKGj9hSqj7kEuJp_jm7bNVXrbcFsgP', partial)
            partial.replace(f_checkpoint)
    return TranslationModel.load_from_checkpoint(f_checkpoint)


@st.experimental_singleton
def getJpEn_Tokenizers():
    """Return ``(ja_tokenizer, en_tokenizer)`` used by the translation page.

    On failure an error is shown and the script run is stopped; previously the
    function fell through to ``return`` with unbound names (UnboundLocalError).
    """
    try:
        with st.spinner("Downloading English and Japanese Transformer Tokenizers"):
            ja_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ja-en")
            en_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    except Exception:
        st.error("Issue with downloading tokenizers")
        st.stop()
    return ja_tokenizer, en_tokenizer


@st.experimental_singleton
def _load_review_tokenizer():
    """Return the rinna RoBERTa SentencePiece tokenizer used for reviews."""
    return T5Tokenizer.from_pretrained('rinna/japanese-roberta-base')


@st.experimental_singleton
def _load_sentiment_model():
    """Return the fine-tuned BERT sentiment classifier (CPU, eval mode)."""
    model = BertForSequenceClassification.from_pretrained("shubh2014shiv/jp_review_sentiments_amzn",
                                                          num_labels=len(SENTIMENT_LABELS),
                                                          output_attentions=False,
                                                          output_hidden_states=False)
    model.load_state_dict(torch.load('reviewSentiments_jp.pt', map_location=torch.device('cpu')))
    model.eval()
    return model


@st.experimental_singleton
def _load_summarizer():
    """Return ``(tokenizer, model)`` for the multilingual mT5 XL-Sum summarizer."""
    summarization_model_name = "csebuetnlp/mT5_multilingual_XLSum"
    tokenizer = AutoTokenizer.from_pretrained(summarization_model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_name)
    return tokenizer, model


# ---------------------------------------------------------------------------
# Small UI helpers
# ---------------------------------------------------------------------------
def _heading(text):
    """Render *text* as an HTML section heading."""
    # NOTE(review): the inline HTML that originally wrapped these headings is not
    # visible in this copy of the script; a plain <h3> is used — restore styling if needed.
    st.markdown(f"<h3>{text}</h3>", unsafe_allow_html=True)


def _single_select_grid_options(df):
    """Build AgGrid options with pagination and single-row checkbox selection."""
    gb = GridOptionsBuilder.from_dataframe(df)
    gb.configure_pagination()
    gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
    return gb.build()


# ---------------------------------------------------------------------------
# Sentiment analysis page
# ---------------------------------------------------------------------------
def _render_sentiment_analysis():
    """Render the 'Sentiment Analysis' page."""
    _heading("Transfer Learning based Japanese Sentiments Analysis using BERT")
    _heading("Japanese Amazon Reviews Data (日本のAmazonレビューデータ)")
    amazon_jp_reviews = pd.read_csv("review_val.csv").sample(frac=1, random_state=10).iloc[:16000]
    # Colours the 'sentiment' cell green for positive reviews, red otherwise.
    cellstyle_jscode = JsCode(
        """
        function(params) {
            if (params.value.includes('positive')) {
                return {
                    'color': 'black',
                    'backgroundColor': '#32CD32'
                }
            } else {
                return {
                    'color': 'black',
                    'backgroundColor': '#FF7F7F'
                }
            }
        };
        """
    )
    # NOTE(review): these write empty strings with unsafe_allow_html=True; the HTML
    # they presumably carried (e.g. CSS) is not present here — confirm.
    st.write('', unsafe_allow_html=True)
    st.write('', unsafe_allow_html=True)
    choose = st.radio("", (SELECT_ONE_REVIEW, WRITE_REVIEW))

    gb = GridOptionsBuilder.from_dataframe(amazon_jp_reviews)
    gb.configure_column("sentiment", cellStyle=cellstyle_jscode)
    gb.configure_pagination()
    if choose == SELECT_ONE_REVIEW:
        gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
    grid_options = gb.build()

    jp_review_text = None
    if choose == SELECT_ONE_REVIEW:
        jp_review_choice = AgGrid(amazon_jp_reviews, gridOptions=grid_options, theme='material',
                                  enable_enterprise_modules=True, allow_unsafe_jscode=True,
                                  update_mode=GridUpdateMode.SELECTION_CHANGED)
        st.info("Select any one of the Japanese Reviews by clicking the checkbox. Reviews can be navigated from each page.")
        selected_rows = jp_review_choice['selected_rows']
        if len(selected_rows) != 0:
            jp_review_text = selected_rows[0]['review']
            _heading("Selected Review in JSON (JSONで選択されたレビュー)")
            st.write(selected_rows)
    elif choose == WRITE_REVIEW:
        AgGrid(amazon_jp_reviews, gridOptions=grid_options, theme='material',
               enable_enterprise_modules=True, allow_unsafe_jscode=True)
        with open("test_reviews_jp.csv", "rb") as file:
            st.download_button(label="Download Additional Japanese Reviews", data=file,
                               file_name="Additional Japanese Reviews.csv")
        st.info("Additional subset of Japanese Reviews can be downloaded and any review can be copied & pasted in text area.")
        sample_japanese_review_input = "子供のレッスンバッグ用に購入。 思ったより大きく、ピアノ教本を入れるには充分でした。中は汚れてました。 何より驚いたのは、商品の梱包。 2つ折は許せるが、透明ビニール袋の底思いっきり空いてますけど? 何これ?包むっていうか挟んで終わり?底が全開している。 引っ張れば誰でも中身の注文書も、商品も見れる状態って何なの? 個人情報が晒されて、商品も粗末な扱いで嫌な気持ちでした。 郵送で中身が無事のが奇跡じゃないでしょうか? ありえない"
        jp_review_text = st.text_area(label="Press 'Ctrl+Enter' after writing review in below text area",
                                      value=sample_japanese_review_input)
        if len(jp_review_text) == 0:
            st.error("Input text cannot be empty. Either write the japanese review in text area manually or select the review from the grid.")

    if jp_review_text:
        _render_review_prediction(jp_review_text)


def _render_review_prediction(jp_review_text):
    """Show tokenization of *jp_review_text* and, on demand, its predicted sentiment."""
    _heading("Sentence-Piece based Japanese Tokenizer using RoBERTA")
    tokens_column, token_id_column = st.columns(2)
    tokenizer = _load_review_tokenizer()
    tokens = tokenizer.tokenize(jp_review_text)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    with tokens_column:
        with st.expander("Expand to see the tokens", expanded=False):
            st.write(tokens)
    with token_id_column:
        with st.expander("Expand to see the token IDs", expanded=False):
            st.write(token_ids)

    _heading("Encoded Japanese Review Text to get Input IDs and attention masks as PyTorch Tensor")
    encoded_data = tokenizer.batch_encode_plus([jp_review_text], add_special_tokens=True,
                                               return_attention_mask=True, padding=True, max_length=200,
                                               return_tensors='pt', truncation=True)
    input_ids = encoded_data['input_ids']
    attention_masks = encoded_data['attention_mask']
    input_ids_column, attention_masks_column = st.columns(2)
    with input_ids_column:
        with st.expander("Expand to see the input IDs tensor"):
            st.write(input_ids)
    with attention_masks_column:
        with st.expander("Expand to see the attention mask tensor"):
            st.write(attention_masks)

    _heading("Predict Sentiment of review using Fine-Tuned Japanese BERT")
    if not st.button("Predict Sentiment"):
        return
    with st.spinner("Wait.."):
        model = _load_sentiment_model()
        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_masks).logits
        # The classifier has two mutually exclusive labels, so softmax yields a
        # probability distribution over them (per-logit sigmoid does not sum to 1).
        probabilities = torch.softmax(logits, dim=-1)[0].cpu().tolist()
    result = {"TEXT (文章)": jp_review_text,
              'NEGATIVE (ネガティブ)': probabilities[SENTIMENT_LABELS['negative']],
              'POSITIVE (ポジティブ)': probabilities[SENTIMENT_LABELS['positive']]}
    result_col, graph_col = st.columns(2)
    with result_col:
        st.write(result)
    with graph_col:
        fig = px.bar(x=['NEGATIVE (ネガティブ)', 'POSITIVE (ポジティブ)'],
                     y=[result['NEGATIVE (ネガティブ)'], result['POSITIVE (ポジティブ)']])
        fig.update_layout(title="Probability distribution of Sentiment for the given text",
                          yaxis_title="Probability (確率)")
        fig.update_traces(marker_color=['#FF7F7F', '#32CD32'])
        st.plotly_chart(fig)


# ---------------------------------------------------------------------------
# Text summarization page
# ---------------------------------------------------------------------------
def _render_text_summarization():
    """Render the 'Text Summarization' page."""
    _heading("Summarizing Japanese News Article using multi-Lingual T5 (mT5)")
    _heading("Japanese News Article Data")
    news_articles = pd.read_csv("jp_news_articles.csv").sample(frac=0.75, random_state=42)
    jp_article = AgGrid(news_articles, gridOptions=_single_select_grid_options(news_articles), theme='material',
                        enable_enterprise_modules=True, allow_unsafe_jscode=True,
                        update_mode=GridUpdateMode.SELECTION_CHANGED)
    if len(jp_article['selected_rows']) == 0:
        st.info("Pick any one Japanese News Article by selecting the checkbox. News articles can be navigated by clicking on page navigator at right-bottom")
        return

    article_text = jp_article['selected_rows'][0]['News Articles']
    text = st.text_area(label="Text from selected Japanese News Article(ニュース記事)", value=article_text, height=500)
    summary_length = st.slider(label="Select the maximum length of summary (要約の最大長を選択します )",
                               min_value=120, max_value=160, step=5)
    if not (text and st.button("Summarize it! (要約しよう)")):
        return

    wait_placeholder = st.image("wait.gif")
    try:
        tokenizer, model = _load_summarizer()
        # Summarize the (possibly user-edited) text area content, not the raw article.
        encoding = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        output_ids = model.generate(input_ids=encoding["input_ids"], attention_mask=encoding["attention_mask"],
                                    max_length=summary_length, no_repeat_ngram_size=2, num_beams=4)[0]
        summary = tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    finally:
        # Remove the wait animation even when summarization fails.
        wait_placeholder.empty()
    _heading("Summary (要約文)")
    st.write(summary)


# ---------------------------------------------------------------------------
# Japanese -> English translation page
# ---------------------------------------------------------------------------
def _translate(text, ja_tokenizer, en_tokenizer, trained_model):
    """Translate Japanese *text* to English with the fine-tuned model."""
    text_encoding = ja_tokenizer(text, max_length=100, padding="max_length", truncation=True,
                                 return_attention_mask=True, add_special_tokens=True, return_tensors='pt')
    generated_ids = trained_model.model.generate(input_ids=text_encoding['input_ids'],
                                                 attention_mask=text_encoding['attention_mask'],
                                                 max_length=100, num_beams=2, repetition_penalty=2.5,
                                                 length_penalty=1.0, early_stopping=True)
    preds = [en_tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
             for gen_id in generated_ids]
    # NOTE(review): drops the first 5 characters of the decoded output, presumably
    # a leading artifact of decoding with the BERT tokenizer — confirm.
    return "".join(preds)[5:]


def _render_translation():
    """Render the 'Japanese to English Translation' page."""
    _heading("Japanese to English translation (for short sentences)")
    _heading("Business Scene Dialog Japanese-English Corpus")
    st.write("Below given Japanese-English pair is from 'Business Scene Dialog Corpus' by the University of Tokyo")
    st.markdown('[Corpus GitHub Link](https://github.com/tsuruoka-lab/BSD)', unsafe_allow_html=True)
    with st.expander(label="Expand to get more information on data and training report"):
        _heading("Training Dataset")
        st.write("The corpus has total 20,000 Japanese-English Business Dialog pairs. The fine-tuned Transformer model is validated on 670 Japanese-English Business Dialog pairs")
        _heading("Training Report")
        st.write("The Dashboard for training result on Tensorboard is [here](https://tensorboard.dev/experiment/eWhxt1i2RuaU64krYtORhw/)")

    with open("./BSD_ja-en_val.json", encoding='utf-8') as f:
        bsd_sample_data = json.load(f)
    turns = [turn for dialog in bsd_sample_data for turn in dialog['conversation']]
    df = pd.DataFrame({'Japanese': [turn['ja_sentence'] for turn in turns],
                       'English': [turn['en_sentence'] for turn in turns]})

    translation_text = AgGrid(df, gridOptions=_single_select_grid_options(df), theme='material',
                              enable_enterprise_modules=True, allow_unsafe_jscode=True,
                              update_mode=GridUpdateMode.SELECTION_CHANGED)
    if len(translation_text['selected_rows']) == 0:
        return

    bsd_jp = translation_text['selected_rows'][0]['Japanese']
    _heading("Business Scene Dialog in Japanese (日本語でのビジネスシーンダイアログ)")
    st.write(bsd_jp)
    if st.button("Translate"):
        ja_tokenizer, en_tokenizer = getJpEn_Tokenizers()
        trained_model = loadFineTunedJaEn_NMT_Model()
        trained_model.freeze()
        _heading("Translated Dialog in English (英語の翻訳されたダイアログ)")
        st.write(_translate(bsd_jp, ja_tokenizer, en_tokenizer, trained_model))


# ---------------------------------------------------------------------------
# Page layout and topic dispatch
# ---------------------------------------------------------------------------
st.set_page_config(layout="wide")
st.title("Project - Japanese Natural Language Processing (自然言語処理) using Transformers")
st.sidebar.subheader("自然言語処理 トピック")
topic = st.sidebar.radio(label="Select the NLP project topics",
                         options=["Sentiment Analysis", "Text Summarization", "Japanese to English Translation"])
st.write("-" * 5)

if topic == "Sentiment Analysis":
    _render_sentiment_analysis()
elif topic == "Text Summarization":
    _render_text_summarization()
elif topic == "Japanese to English Translation":
    _render_translation()