import streamlit as st |
import sparknlp |
import pandas as pd |
import json |
from sparknlp.base import * |
from sparknlp.annotator import * |
from pyspark.ml import Pipeline |
from sparknlp.pretrained import PretrainedPipeline |
# Page-level configuration: wide layout, automatic sidebar state.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="auto",
)

# Inject the stylesheet used by the title banner (.main-title) and the grey
# description card (.section) that are rendered later in this script.  The
# CSS is passed as raw HTML, hence unsafe_allow_html=True.  Stray trailing
# "|" characters were removed from the literal: inside a <style> element
# they sit between declarations and make them malformed.
st.markdown("""
<style>
.main-title {
font-size: 36px;
color: #4A90E2;
font-weight: bold;
text-align: center;
}
.section {
background-color: #f9f9f9;
padding: 10px;
border-radius: 10px;
margin-top: 10px;
}
.section p, .section ul {
color: #666666;
}
</style>
""", unsafe_allow_html=True)
@st.cache_resource
def init_spark():
    """Start Spark NLP and hand back whatever ``sparknlp.start()`` returns.

    The function is wrapped in ``st.cache_resource`` so the started session
    is cached by Streamlit rather than created again on every call.
    """
    session = sparknlp.start()
    return session
@st.cache_resource
def create_pipeline(model):
    """Assemble the unfitted Spark ML pipeline for table question answering.

    Stages, in order:
      1. ``MultiDocumentAssembler``: ``table_json``/``questions`` columns ->
         ``document_table``/``document_questions``.
      2. ``SentenceDetector``: ``document_questions`` -> ``questions``.
      3. ``TableAssembler``: ``document_table`` -> ``table``.
      4. ``TapasForQuestionAnswering`` (WTQ checkpoint) -> ``answers_wtq``.
      5. ``TapasForQuestionAnswering`` (SQA checkpoint) -> ``answers_sqa``.

    Args:
        model: Not read by the body -- both checkpoints are always loaded.
            NOTE(review): because the function is wrapped in
            ``st.cache_resource``, the argument still takes part in the cache
            key, so each distinct value builds (and loads) a separate
            pipeline.

    Returns:
        A ``pyspark.ml.Pipeline`` containing the five stages above.
    """
    assembler = (
        MultiDocumentAssembler()
        .setInputCols("table_json", "questions")
        .setOutputCols("document_table", "document_questions")
    )

    question_splitter = (
        SentenceDetector()
        .setInputCols(["document_questions"])
        .setOutputCol("questions")
    )

    table_builder = (
        TableAssembler()
        .setInputCols(["document_table"])
        .setOutputCol("table")
    )

    # One TAPAS annotator per checkpoint; each reads the same inputs and
    # writes to its own output column.  Order (WTQ, then SQA) is preserved.
    checkpoints = (
        ("table_qa_tapas_base_finetuned_wtq", "answers_wtq"),
        ("table_qa_tapas_base_finetuned_sqa", "answers_sqa"),
    )
    tapas_stages = [
        TapasForQuestionAnswering.pretrained(checkpoint, "en")
        .setInputCols(["questions", "table"])
        .setOutputCol(output_col)
        for checkpoint, output_col in checkpoints
    ]

    return Pipeline(
        stages=[assembler, question_splitter, table_builder, *tapas_stages]
    )
def fit_data(pipeline, json_data, question):
    """Run a single table/question pair through ``pipeline`` and collect the answers.

    Args:
        pipeline: Pipeline to fit and apply; its output must expose the
            ``answers_wtq`` and ``answers_sqa`` columns selected below.
        json_data: Value placed in the ``table_json`` column.
        question: Value placed in the ``questions`` column.

    Returns:
        The list produced by ``collect()``; each row carries
        ``answers_wtq.result`` followed by ``answers_sqa.result``.

    Note:
        Reads the module-level name ``spark``; it must be bound before this
        function is called.
    """
    single_row = [[json_data, question]]
    input_df = spark.createDataFrame(single_row).toDF("table_json", "questions")

    fitted = pipeline.fit(input_df)
    answers_df = fitted.transform(input_df).select(
        "answers_wtq.result", "answers_sqa.result"
    )
    return answers_df.collect()
# Sidebar: choice of pretrained TAPAS checkpoint.  The selected value is
# passed to create_pipeline() further down.
model = st.sidebar.selectbox(
    "Choose the pretrained model",
    ["table_qa_tapas_base_finetuned_wtq", "table_qa_tapas_base_finetuned_sqa"],
    help="For more info about the models visit: https://sparknlp.org/models",
)

# Page title and description card.  Both are emitted as raw HTML so they pick
# up the .main-title / .section CSS classes.  Stray trailing "|" characters
# were removed from the description literal; they would otherwise be shown
# to the user as part of the text.
title = 'TAPAS for Table-Based Question Answering with Spark NLP'
sub_title = ("""
TAPAS (Table Parsing Supervised via Pre-trained Language Models) enhances the BERT architecture to effectively process tabular data, allowing it to answer complex questions about tables without needing to convert them into text.<br>
<br>
<strong>table_qa_tapas_base_finetuned_wtq:</strong> This model excels at answering questions that require aggregating data across the entire table, such as calculating sums or averages.<br>
<strong>table_qa_tapas_base_finetuned_sqa:</strong> This model is designed for sequential question-answering tasks where the answer to each question may depend on the context provided by previous answers.
""")

st.markdown(f'<div class="main-title">{title}</div>', unsafe_allow_html=True)
st.markdown(f'<div class="section"><p>{sub_title}</p></div>', unsafe_allow_html=True)

# Sidebar: "Open in Colab" badge linking to the reference notebook.
# NOTE(review): the URL names the NER_HINDI_ENGLISH notebook -- confirm this
# is the intended reference for a TAPAS demo.
link = """
<a href="https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/streamlit_notebooks/NER_HINDI_ENGLISH.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" style="zoom: 1.3" alt="Open In Colab"/>
</a>
"""
st.sidebar.markdown('Reference notebook:')
st.sidebar.markdown(link, unsafe_allow_html=True)
# Sample table in the {"header": [...], "rows": [[...], ...]} layout; it is
# parsed with json.loads() further down to seed the editable grid.  Every
# cell is a string.  Stray trailing "|" characters were removed from the
# literal -- with them the text is not valid JSON and json.loads() raises.
json_data = '''
{
"header": ["name", "net_worth", "age", "nationality", "company", "industry"],
"rows": [
["Elon Musk", "$200,000,000,000", "52", "American", "Tesla, SpaceX", "Automotive, Aerospace"],
["Jeff Bezos", "$150,000,000,000", "60", "American", "Amazon", "E-commerce"],
["Bernard Arnault", "$210,000,000,000", "74", "French", "LVMH", "Luxury Goods"],
["Bill Gates", "$120,000,000,000", "68", "American", "Microsoft", "Technology"],
["Warren Buffett", "$110,000,000,000", "93", "American", "Berkshire Hathaway", "Conglomerate"],
["Larry Page", "$100,000,000,000", "51", "American", "Google", "Technology"],
["Mark Zuckerberg", "$85,000,000,000", "40", "American", "Meta", "Social Media"],
["Mukesh Ambani", "$80,000,000,000", "67", "Indian", "Reliance Industries", "Conglomerate"],
["Alice Walton", "$65,000,000,000", "74", "American", "Walmart", "Retail"],
["Francoise Bettencourt Meyers", "$70,000,000,000", "70", "French", "L'Oreal", "Cosmetics"],
["Amancio Ortega", "$75,000,000,000", "88", "Spanish", "Inditex (Zara)", "Retail"],
["Carlos Slim", "$55,000,000,000", "84", "Mexican", "America Movil", "Telecom"]
]
}
'''

# Preset questions offered in the "Question Query" drop-down.
queries = [
    "Who has a higher net worth, Bernard Arnault or Jeff Bezos?",
    "List the top three individuals by net worth.",
    "Who is the richest person in the technology industry?",
    "Which company in the e-commerce industry has the highest net worth?",
    "Who is the oldest billionaire on the list?",
    "Which individual under the age of 60 has the highest net worth?",
    "Who is the wealthiest American, and which company do they own?",
    "Find all French billionaires and list their companies.",
    "How many women are on the list, and what are their total net worths?",
    "Who is the wealthiest non-American on the list?",
    "Find the person who is the youngest and has a net worth over $100 billion.",
    "Who owns companies in more than one industry, and what are those industries?",
    "What is the total net worth of all individuals over 70?",
    "How many billionaires are in the conglomerate industry?",
]
# Parse the sample table and show it as an editable grid whose row labels
# start at 1 instead of 0.
sample_table = json.loads(json_data)
df_table = pd.DataFrame(data=sample_table["rows"], columns=sample_table["header"])
df_table.index = df_table.index + 1

st.write("")
st.write("Context DataFrame (Click To Edit)")
edited_df = st.data_editor(df_table)

# Serialise the (possibly edited) grid back into the header/rows JSON layout.
table_json_str = json.dumps(
    {
        "header": edited_df.columns.tolist(),
        "rows": edited_df.values.tolist(),
    }
)

# A non-empty typed question takes priority over the preset from the drop-down.
selected_text = st.selectbox("Question Query", queries)
custom_input = st.text_input("Try it with your own Question!")
text_to_analyze = custom_input or selected_text

# fit_data() reads the module-level name `spark`, so it is bound before the call.
spark = init_spark()
pipeline = create_pipeline(model)
output = fit_data(pipeline, table_json_str, text_to_analyze)
def _format_answers(answers):
    """Return the answers joined with ", ", or a placeholder when there are none.

    The placeholder is returned as-is rather than being passed through
    ``str.join``: joining a plain string iterates over its characters and
    would render "No results found." as "N, o,  , r, e, ...".
    """
    return ", ".join(answers) if answers else "No results found."


st.markdown("---")
st.subheader("Processed Output")

# `output` is the collected result of fit_data(); each row carries the WTQ
# answers in field 0 and the SQA answers in field 1.
if output:
    wtq_answers, sqa_answers = output[0][0], output[0][1]
    st.markdown(f"**Answers from WTQ model:** {_format_answers(wtq_answers)}")
    st.markdown(f"**Answers from SQA model:** {_format_answers(sqa_answers)}")