# LIDAwStreamlit / app.py
# Author: djangomango — commit df8dcb1 ("Update app.py"), 2.56 kB.
# (Header taken from the Hugging Face file page: raw / history / blame views, "No virus" scan badge.)
import base64
import io
import os
import tempfile
from io import BytesIO

import streamlit as st
from lida import Manager, TextGenerationConfig, llm
from PIL import Image
def base64_to_image(base64_string):
    """Turn a base64-encoded image payload into a PIL image.

    Args:
        base64_string: Base64 text of an encoded image, such as the
            ``raster`` attribute of a chart returned by ``lida.visualize``.

    Returns:
        The image produced by ``PIL.Image.open`` over the decoded bytes.
    """
    return Image.open(BytesIO(base64.b64decode(base64_string)))
#from LIDA github
text_gen = llm(model="uukuguy/speechless-llama2-hermes-orca-platypus-13b", device_map="auto")
lida = Manager(text_gen=text_gen)
textgen_config = TextGenerationConfig(n=1, temperature=0.1, max_tokens=512)
menu = st.sidebar.selectbox("Summary or Query", ["Summary", "Query"])


def _save_upload(uploaded_file, directory, file_name):
    """Write an uploaded file's bytes to ``directory/file_name``.

    Args:
        uploaded_file: Object returned by ``st.file_uploader`` (must provide
            ``getvalue()``).
        directory: Existing directory to write into.
        file_name: Name of the file to create there.

    Returns:
        The full path of the written file.
    """
    csv_path = os.path.join(directory, file_name)
    with open(csv_path, "wb") as f:
        f.write(uploaded_file.getvalue())
    return csv_path


def _show_first_chart(charts):
    """Display the first chart LIDA generated, or an error if there is none.

    Args:
        charts: The list returned by ``lida.visualize``.

    Returns:
        The chart whose image was displayed, or ``None`` when ``charts`` is
        empty or its first entry has no raster image.
    """
    if not charts or not charts[0].raster:
        st.error("LIDA could not generate a chart for this request.")
        return None
    chart = charts[0]
    st.image(base64_to_image(chart.raster))
    return chart


if menu == "Summary":
    st.subheader("Summarization of the Data")
    file_uploader = st.file_uploader("Upload your file", type="csv")
    if file_uploader is not None:
        # A per-run temporary directory keeps each session's upload private;
        # a fixed file in the working directory would be shared by all users.
        # The file name itself is unchanged from earlier versions of the app.
        with tempfile.TemporaryDirectory() as tmp_dir:
            csv_path = _save_upload(file_uploader, tmp_dir, "filename.csv")
            summary = lida.summarize(csv_path, summary_method="default")
            st.write(summary)
            goals = lida.goals(summary, n=2, textgen_config=textgen_config)
            for goal in goals:
                st.write(goal)
            # Plot the first proposed goal; guard against LIDA returning none.
            if goals:
                charts = lida.visualize(
                    summary=summary,
                    goal=goals[0],
                    textgen_config=textgen_config,
                    library="seaborn",
                )
                _show_first_chart(charts)
            else:
                st.warning("LIDA did not propose any visualization goals for this data.")
elif menu == "Query":
    st.subheader("Questioning of the Data")
    file_uploader = st.file_uploader("Upload your file", type="csv")
    if file_uploader is not None:
        text_area = st.text_area("Query your data to generate visualization", height=200)
        if st.button("Generate Graph"):
            user_query = text_area.strip()
            # Whitespace-only input is not a usable query for the model.
            if user_query:
                with tempfile.TemporaryDirectory() as tmp_dir:
                    csv_path = _save_upload(file_uploader, tmp_dir, "filename1.csv")
                    summary = lida.summarize(
                        csv_path,
                        summary_method="default",
                        textgen_config=textgen_config,
                    )
                    charts = lida.visualize(
                        summary=summary,
                        goal=user_query,
                        textgen_config=textgen_config,
                    )
                    chart = _show_first_chart(charts)
                    if chart is not None:
                        # Explicitly show the chart object (previously done via
                        # Streamlit "magic" on a bare `charts[0]` expression).
                        st.write(chart)
            else:
                st.warning("Please enter a query before generating a graph.")