import json

import pandas as pd
import streamlit as st
import sparknlp

from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp.pretrained import PretrainedPipeline
from pyspark.ml import Pipeline

st.set_page_config(
    layout="wide",
    initial_sidebar_state="auto"
)

st.markdown("""
    <style>
        .main-title {
            font-size: 36px;
            color: #4A90E2;
            font-weight: bold;
            text-align: center;
        }
        .section {
            background-color: #f9f9f9;
            padding: 10px;
            border-radius: 10px;
            margin-top: 10px;
        }
        .section p, .section ul {
            color: #666666;
        }
    </style>
""", unsafe_allow_html=True)


@st.cache_resource
def init_spark():
    # Start (or reuse) the Spark session with Spark NLP; cached across Streamlit reruns
    return sparknlp.start()


@st.cache_resource
def create_pipeline(model):
    # Read the serialized table and the question text into two document columns
    document_assembler = MultiDocumentAssembler() \
        .setInputCols("table_json", "questions") \
        .setOutputCols("document_table", "document_questions")

    # Split the question document into individual questions
    sentence_detector = SentenceDetector() \
        .setInputCols(["document_questions"]) \
        .setOutputCol("questions")

    # Parse the JSON table into a TABLE annotation
    table_assembler = TableAssembler() \
        .setInputCols(["document_table"]) \
        .setOutputCol("table")

    # Both TAPAS variants are loaded so their answers can be compared side by side;
    # the selected `model` only affects the cache key of this function
    tapas_wtq = TapasForQuestionAnswering \
        .pretrained("table_qa_tapas_base_finetuned_wtq", "en") \
        .setInputCols(["questions", "table"]) \
        .setOutputCol("answers_wtq")

    tapas_sqa = TapasForQuestionAnswering \
        .pretrained("table_qa_tapas_base_finetuned_sqa", "en") \
        .setInputCols(["questions", "table"]) \
        .setOutputCol("answers_sqa")

    pipeline = Pipeline(stages=[document_assembler, sentence_detector, table_assembler, tapas_wtq, tapas_sqa])
    return pipeline
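
# Note: pretrained() downloads each TAPAS model the first time it is used and caches it
# locally (by default under ~/cache_pretrained), so the first run can take a few minutes.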


def fit_data(pipeline, json_data, question):
    # Build a one-row DataFrame holding the table JSON and the question, then run the pipeline
    spark_df = spark.createDataFrame([[json_data, question]]).toDF("table_json", "questions")
    model = pipeline.fit(spark_df)
    res = model.transform(spark_df)
    return res.select("answers_wtq.result", "answers_sqa.result").collect()
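
# Note: collect() returns one Row per input row; "answers_wtq.result" and "answers_sqa.result"
# are the lists of answer strings produced by each model, so output[0][0] holds the WTQ
# answers and output[0][1] the SQA answers for the submitted question(s).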


# Sidebar model selector (both models are run by the pipeline; see create_pipeline)
model = st.sidebar.selectbox(
    "Choose the pretrained model",
    ["table_qa_tapas_base_finetuned_wtq", "table_qa_tapas_base_finetuned_sqa"],
    help="For more info about the models visit: https://sparknlp.org/models"
)


title = 'TAPAS for Table-Based Question Answering with Spark NLP'
sub_title = ("""
    TAPAS (Table Parsing Supervised via Pre-trained Language Models) extends the BERT architecture to process tabular data directly, so it can answer complex questions about a table without first converting the table to text.<br>
    <br>
    <strong>table_qa_tapas_base_finetuned_wtq:</strong> This model excels at questions that require aggregating data across the entire table, such as calculating sums or averages.<br>
    <strong>table_qa_tapas_base_finetuned_sqa:</strong> This model is designed for sequential question answering, where the answer to each question may depend on the context established by previous answers.
""")

st.markdown(f'<div class="main-title">{title}</div>', unsafe_allow_html=True)
st.markdown(f'<div class="section"><p>{sub_title}</p></div>', unsafe_allow_html=True)


link = """
    <a href="https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/streamlit_notebooks/NER_HINDI_ENGLISH.ipynb">
        <img src="https://colab.research.google.com/assets/colab-badge.svg" style="zoom: 1.3" alt="Open In Colab"/>
    </a>
"""
st.sidebar.markdown('Reference notebook:')
st.sidebar.markdown(link, unsafe_allow_html=True)


json_data = '''
{
    "header": ["name", "net_worth", "age", "nationality", "company", "industry"],
    "rows": [
        ["Elon Musk", "$200,000,000,000", "52", "American", "Tesla, SpaceX", "Automotive, Aerospace"],
        ["Jeff Bezos", "$150,000,000,000", "60", "American", "Amazon", "E-commerce"],
        ["Bernard Arnault", "$210,000,000,000", "74", "French", "LVMH", "Luxury Goods"],
        ["Bill Gates", "$120,000,000,000", "68", "American", "Microsoft", "Technology"],
        ["Warren Buffett", "$110,000,000,000", "93", "American", "Berkshire Hathaway", "Conglomerate"],
        ["Larry Page", "$100,000,000,000", "51", "American", "Google", "Technology"],
        ["Mark Zuckerberg", "$85,000,000,000", "40", "American", "Meta", "Social Media"],
        ["Mukesh Ambani", "$80,000,000,000", "67", "Indian", "Reliance Industries", "Conglomerate"],
        ["Alice Walton", "$65,000,000,000", "74", "American", "Walmart", "Retail"],
        ["Francoise Bettencourt Meyers", "$70,000,000,000", "70", "French", "L'Oreal", "Cosmetics"],
        ["Amancio Ortega", "$75,000,000,000", "88", "Spanish", "Inditex (Zara)", "Retail"],
        ["Carlos Slim", "$55,000,000,000", "84", "Mexican", "America Movil", "Telecom"]
    ]
}
'''
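
# The JSON layout above (a "header" list plus "rows") is the structure TableAssembler parses;
# it can also read CSV input via its inputFormat parameter, though only JSON is used in this demo.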


queries = [
    "Who has a higher net worth, Bernard Arnault or Jeff Bezos?",
    "List the top three individuals by net worth.",
    "Who is the richest person in the technology industry?",
    "Which company in the e-commerce industry has the highest net worth?",
    "Who is the oldest billionaire on the list?",
    "Which individual under the age of 60 has the highest net worth?",
    "Who is the wealthiest American, and which company do they own?",
    "Find all French billionaires and list their companies.",
    "How many women are on the list, and what are their total net worths?",
    "Who is the wealthiest non-American on the list?",
    "Find the person who is the youngest and has a net worth over $100 billion.",
    "Who owns companies in more than one industry, and what are those industries?",
    "What is the total net worth of all individuals over 70?",
    "How many billionaires are in the conglomerate industry?"
]


# Render the sample table as an editable DataFrame
table_data = json.loads(json_data)
df_table = pd.DataFrame(table_data["rows"], columns=table_data["header"])
df_table.index += 1

st.write("")
st.write("Context DataFrame (Click To Edit)")
edited_df = st.data_editor(df_table)


# Serialize the (possibly edited) table back into the JSON format the pipeline expects
table_json_data = {
    "header": edited_df.columns.tolist(),
    "rows": edited_df.values.tolist()
}
table_json_str = json.dumps(table_json_data)


selected_text = st.selectbox("Question Query", queries)
custom_input = st.text_input("Try it with your own Question!")
text_to_analyze = custom_input if custom_input else selected_text
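
# Because the pipeline runs a SentenceDetector over the question document, a custom input
# may contain several questions in one string, e.g. (illustrative):
#   "Who is the oldest billionaire on the list? Which company does he own?"
# Each detected question then receives its own answer in the output annotations.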


spark = init_spark()
pipeline = create_pipeline(model)

output = fit_data(pipeline, table_json_str, text_to_analyze)


st.markdown("---")
st.subheader("Processed Output")

if output:
    # Each result field is a list of answer strings; join them for display and fall back
    # to a message when a model returns nothing (joining the fallback string itself would
    # otherwise split it into characters)
    results_wtq = output[0][0]
    results_sqa = output[0][1]
    st.markdown(f"**Answers from WTQ model:** {', '.join(results_wtq) if results_wtq else 'No results found.'}")
    st.markdown(f"**Answers from SQA model:** {', '.join(results_sqa) if results_sqa else 'No results found.'}")