# t5 / Demo.py
# Uploaded by abdullahmubeen10 ("Upload 177 files", commit dcdb825, verified; 7.26 kB)
import html
import os

import pandas as pd
import sparknlp
import streamlit as st
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline
from sparknlp.pretrained import PretrainedPipeline
# Page Configuration: wide layout; "auto" leaves the sidebar's initial
# expanded/collapsed state up to Streamlit.
page_settings = dict(layout="wide", initial_sidebar_state="auto")
st.set_page_config(**page_settings)
# Custom CSS for Styling: defines the .main-title and .section-content classes
# used by the HTML snippets this script renders with st.markdown further down.
CUSTOM_CSS = """
<style>
.main-title {
font-size: 36px;
color: #4A90E2;
font-weight: bold;
text-align: center;
}
.section-content {
background-color: #f9f9f9;
padding: 10px;
border-radius: 10px;
margin-top: 10px;
}
.section-content p, .section-content ul {
color: #666666;
}
</style>
"""
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
# Initialize Spark Session
@st.cache_resource
def start_spark_session():
    """Return the Spark session shared by every pipeline run in this app.

    The session is created by ``sparknlp.start()``. ``st.cache_resource``
    keeps the returned object across Streamlit reruns, so it is not
    restarted on every interaction.
    """
    session = sparknlp.start()
    return session
# Create NLP Pipeline
@st.cache_resource
def build_nlp_pipeline(model_name, task):
    """Assemble an unfitted two-stage Spark NLP pipeline for one T5 task.

    Stage 1 wraps the raw ``text`` column into ``document`` annotations.
    Stage 2 loads the pretrained English T5 model ``model_name``, configures
    it with the task prefix ``task``, and writes its results to ``output``.

    Args:
        model_name: Name of the pretrained T5 model (e.g. "t5_small").
        task: Task prefix handed to ``setTask`` (e.g. "cola:").

    Returns:
        A ``pyspark.ml.Pipeline``. It is cached by Streamlit per
        (model_name, task) argument pair.
    """
    assembler = (
        DocumentAssembler()
        .setInputCol("text")
        .setOutputCol("document")
    )
    t5_stage = (
        T5Transformer.pretrained(model_name, 'en')
        .setTask(task)
        .setInputCols(["document"])
        .setOutputCol("output")
    )
    stages = [assembler, t5_stage]
    return Pipeline().setStages(stages)
# Apply Pipeline to Text Data
def process_text(pipeline, text):
    """Run ``pipeline`` on a single string and return the collected results.

    Relies on the module-level ``spark`` session, which is assigned before
    this function is called. A one-row DataFrame with a ``text`` column is
    built, the pipeline is fitted on it, and the same frame is transformed.

    Returns:
        The list of Rows produced by selecting ``output.result``.
    """
    input_df = spark.createDataFrame([[text]]).toDF("text")
    fitted_model = pipeline.fit(input_df)
    annotated = fitted_model.transform(input_df)
    return annotated.select('output.result').collect()
# Model and Task Information: display metadata for each selectable model,
# one dict per model with keys model_name, title and description.
model_info = [
    {"model_name": name, "title": title, "description": description}
    for name, title, description in (
        (
            "t5_small",
            "Multi-Task NLP Model",
            "The T5 model performs 18 different NLP tasks including summarization, question answering, and grammatical correctness detection.",
        ),
        (
            "t5_base",
            "Multi-Task NLP Model",
            "A larger variant of the T5 model, capable of performing a variety of NLP tasks with improved accuracy.",
        ),
        (
            "google_t5_small_ssm_nq",
            "Question Answering Model",
            "This model is fine-tuned for answering questions based on the Natural Questions dataset, leveraging pre-training on large text corpora.",
        ),
    )
]
# Task catalogue shown in the sidebar: "<Category> - <tag>" -> description.
# The tag after " - " becomes the T5 task prefix, and the full key is also used
# as a directory name under inputs/ for the example texts. Keys must therefore
# stay exactly as they are, even where the category label does not match the
# description (e.g. mrpc/qqp are listed under "Coreference Resolution" but
# describe paraphrase detection; stsb is under "Sentiment Analysis" but
# describes similarity scoring).
task_descriptions = {
    'Sentence Classification - cola': "Classify whether a sentence is grammatically correct.",
    'Natural Language Inference - rte': "The RTE task is defined as recognizing, given two text fragments, whether the meaning of one text can be inferred (entailed) from the other or not.",
    'Natural Language Inference - mnli': "Classify for a hypothesis and premise whether they contradict or imply each other, or neither (3 classes).",
    'Natural Language Inference - qnli': "Classify whether the answer to a question can be deduced from an answer candidate.",
    'Natural Language Inference - cb': "Classify for a premise and a hypothesis whether they contradict each other or not (binary).",
    'Coreference Resolution - mrpc': "Classify whether a pair of sentences is a re-phrasing of each other (semantically equivalent).",
    'Coreference Resolution - qqp': "Classify whether a pair of questions is a re-phrasing of each other (semantically equivalent).",
    'Sentiment Analysis - sst2': "Classify the sentiment of a sentence as positive or negative.",
    'Sentiment Analysis - stsb': "Measure how similar two sentences are on a scale from 0 to 5.",
    'Question Answering - copa': "Classify for a question, a premise, and 2 choices which choice is the correct one (binary).",
    'Question Answering - multirc': "Classify for a question, a paragraph of text, and an answer candidate whether the answer is correct (binary).",
    'Question Answering - squad': "Answer a question for a given context.",
    'Word Sense Disambiguation - wic': "Classify for a pair of sentences and an ambiguous word whether the word has the same meaning in both sentences.",
    'Text - summarization': "Summarize text into a shorter representation.",
    # The wmt1-3 tags are mapped to fixed "translate English to X:" prefixes
    # before the pipeline runs, so each description names its fixed direction.
    'Translation - wmt1': "Translate English text to German.",
    'Translation - wmt2': "Translate English text to French.",
    'Translation - wmt3': "Translate English text to Romanian.",
}
# Sidebar: Task and Model Selection
selected_task = st.sidebar.selectbox("Choose an NLP Task", list(task_descriptions))
# The pipeline task prefix is the tag after the last " - " plus a colon,
# e.g. 'Sentence Classification - cola' -> 'cola:'.
task_for_pipeline = selected_task.rpartition(' - ')[2] + ":"
# Question-answering tasks are offered only the QA fine-tuned model; every
# other task may use either general-purpose T5 checkpoint.
if "Question Answering" in selected_task:
    available_models = ['google_t5_small_ssm_nq']
else:
    available_models = ['t5_base', 't5_small']
selected_model = st.sidebar.selectbox("Choose a Model", available_models)
# Get Model Info: find the metadata entry for the chosen model, falling back
# to placeholder text when it is not listed in model_info.
model_details = None
for candidate in model_info:
    if candidate['model_name'] == selected_model:
        model_details = candidate
        break
if model_details:
    app_title = model_details['title']
    app_description = model_details['description']
else:
    app_title = "Unknown Model"
    app_description = "No description available."
# Display Model Info: page title, model description panel, then the
# description of the selected task as a subheader.
title_html = f'<div class="main-title">{app_title}</div>'
description_html = f'<div class="section-content"><p>{app_description}</p></div>'
for fragment in (title_html, description_html):
    st.markdown(fragment, unsafe_allow_html=True)
st.subheader(task_descriptions[selected_task])
# Load Example Texts
def _load_example_texts(folder):
    """Return the stripped, non-blank lines of every ``.txt`` file in *folder*.

    Files are visited in sorted name order because ``os.listdir`` returns
    entries in arbitrary order, and that order would otherwise leak into the
    example dropdown. Each file is opened in a ``with`` block so its handle
    is closed after reading. Blank lines are skipped so they do not appear
    as empty selectable examples.

    Raises:
        FileNotFoundError: if *folder* does not exist (same as before).
    """
    texts = []
    for file_name in sorted(os.listdir(folder)):
        if not file_name.endswith('.txt'):
            continue
        with open(os.path.join(folder, file_name), 'r', encoding='utf-8') as handle:
            for line in handle:
                text = line.strip()
                if text:
                    texts.append(text)
    return texts


example_folder = f"inputs/{selected_task}/{selected_model}"
example_texts = _load_example_texts(example_folder)
# User Input: pick a bundled example or type custom text.
selected_example = st.selectbox("Select an Example", example_texts)
custom_input = st.text_input("Or enter your own text:")
# Typed text wins; an empty text box falls back to the chosen example.
text_to_process = custom_input or selected_example
# Display Selected Text
st.subheader('Selected Text')
# The text comes from the user or an example file and is placed inside raw HTML
# (unsafe_allow_html=True), so escape it so characters like "<" or "&" show up
# literally instead of being parsed as markup. `or ""` keeps html.escape from
# failing when no example is available and nothing was typed (value can be None).
st.markdown(
    f'<div class="section-content">{html.escape(text_to_process or "")}</div>',
    unsafe_allow_html=True,
)
# Sidebar: Reference Notebook — a label plus a Colab badge that links to the
# companion notebook on GitHub.
NOTEBOOK_URL = "https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/streamlit_notebooks/T5TRANSFORMER.ipynb"
COLAB_BADGE_URL = "https://colab.research.google.com/assets/colab-badge.svg"
sidebar = st.sidebar
sidebar.markdown('Reference notebook:')
sidebar.markdown(
    f"""
<a href="{NOTEBOOK_URL}">
<img src="{COLAB_BADGE_URL}" style="zoom: 1.3" alt="Open In Colab"/>
</a>
""",
    unsafe_allow_html=True,
)
# Special Cases for Translation Tasks: the UI tags wmt1-3 are replaced by an
# explicit "translate English to X:" prefix; other tags pass through unchanged.
TRANSLATION_PREFIXES = {
    'wmt1:': 'translate English to German:',
    'wmt2:': 'translate English to French:',
    'wmt3:': 'translate English to Romanian:',
}
if task_for_pipeline in TRANSLATION_PREFIXES:
    task_for_pipeline = TRANSLATION_PREFIXES[task_for_pipeline]
# Initialize Spark, Build Pipeline, and Process Text
spark = start_spark_session()
nlp_pipeline = build_nlp_pipeline(selected_model, task_for_pipeline)
# With no example available and an empty text box, text_to_process is None
# or "". Passing that to process_text would build a DataFrame from a None
# value, which is expected to fail inside Spark. Stop this rerun with a
# readable message instead of a stack trace.
if not text_to_process:
    st.warning("Please select an example or enter some text to process.")
    st.stop()
processed_output = process_text(nlp_pipeline, text_to_process)
# Display Processed Output
st.subheader("Processed Output")
# processed_output holds the collected Rows of `output.result`. Take the first
# row's first column (an iterable of result strings) and concatenate it.
output_text = "".join(processed_output[0][0])
# Model output is placed inside raw HTML (unsafe_allow_html=True), so escape
# it to render characters like "<" or "&" literally instead of as markup.
st.markdown(
    f'<div class="section-content">{html.escape(output_text)}</div>',
    unsafe_allow_html=True,
)