# Hugging Face Space — status: Running
import streamlit as st | |
from lida import Manager, TextGenerationConfig, llm | |
from dotenv import load_dotenv | |
import os | |
import io | |
from PIL import Image | |
from io import BytesIO | |
import base64 | |
# Populate os.environ from a local .env file (when one exists) before any
# secrets are read; load_dotenv was imported but previously never called, so
# .env values were silently ignored. Existing environment variables win.
load_dotenv()

# Hugging Face access token, or None when HF_TOKEN is not set.
# NOTE(review): not referenced by the code visible here — presumably consumed
# implicitly by the HF tooling via the environment; confirm it is still needed.
HUGGINGFACE_API_TOKEN = os.environ.get("HF_TOKEN")
def base64_to_image(base64_string):
    """Turn a base64-encoded image payload into a PIL image.

    Args:
        base64_string: Base64 text of an encoded image, such as the ``raster``
            field of a chart returned by ``lida.visualize``.

    Returns:
        The ``PIL.Image.Image`` opened over the decoded bytes.
    """
    raw_bytes = base64.b64decode(base64_string)
    buffer = BytesIO(raw_bytes)
    return Image.open(buffer)
# Model setup, following the Hugging Face example in the LIDA README.
@st.cache_resource
def _load_text_generator():
    """Load the Hugging Face text generator once per server process.

    Streamlit re-executes this whole script on every widget interaction; without
    ``st.cache_resource`` the 13B model would be reloaded on each rerun.
    ``provider="hf"`` is required: without it, llmx falls back to its configured
    default provider and ignores the ``model`` argument.
    """
    return llm(
        provider="hf",
        model="uukuguy/speechless-llama2-hermes-orca-platypus-13b",
        device_map="auto",
    )


text_gen = _load_text_generator()
# The Manager is cheap to build and keeps per-dataset state between summarize()
# and visualize(), so it is created per run rather than shared across sessions.
lida = Manager(text_gen=text_gen)
# Single low-temperature completion, capped at 512 tokens, for goals/charts.
textgen_config = TextGenerationConfig(n=1, temperature=0.1, max_tokens=512)
def _save_upload(uploaded_file, path):
    """Write the raw bytes of a Streamlit upload to ``path`` and return ``path``."""
    with open(path, "wb") as f:
        f.write(uploaded_file.getvalue())
    return path


def _show_first_chart(charts):
    """Render the raster image of the first chart in ``charts``.

    Reports problems in the UI instead of crashing. There may be no charts at
    all, or the first chart may have no ``raster`` (e.g. when the generated
    plotting code failed). Previously these cases raised IndexError or a
    base64 TypeError.

    :param charts: list returned by ``lida.visualize``.
    :return: the rendered chart, or ``None`` when there was nothing to show.
    """
    if not charts:
        st.error("No visualization could be generated.")
        return None
    chart = charts[0]
    if not chart.raster:
        st.error("The generated visualization code did not produce an image.")
        error = getattr(chart, "error", None)
        if error:
            st.write(error)
        return None
    st.image(base64_to_image(chart.raster))
    return chart


menu = st.sidebar.selectbox("Summary or Query", ["Summary", "Query"])

if menu == "Summary":
    st.subheader("Summarization of the Data")
    file_uploader = st.file_uploader("Upload your file", type="csv")
    if file_uploader is not None:
        csv_path = _save_upload(file_uploader, "filename.csv")
        summary = lida.summarize(csv_path, summary_method="default")
        st.write(summary)
        # Ask the LLM for two candidate analysis goals and list them all.
        goals = lida.goals(summary, n=2, textgen_config=textgen_config)
        for goal in goals:
            st.write(goal)
        # Only the first goal is visualized.
        if goals:
            charts = lida.visualize(
                summary=summary,
                goal=goals[0],
                textgen_config=textgen_config,
                library="seaborn",
            )
            _show_first_chart(charts)
        else:
            st.warning("No goals were generated for this dataset.")

elif menu == "Query":
    st.subheader("Questioning of the Data")
    file_uploader = st.file_uploader("Upload your file", type="csv")
    if file_uploader is not None:
        csv_path = _save_upload(file_uploader, "filename1.csv")
        text_area = st.text_area("Query your data to generate visualization", height=200)
        if st.button("Generate Graph"):
            user_query = text_area.strip()
            if user_query:
                summary = lida.summarize(
                    csv_path, summary_method="default", textgen_config=textgen_config
                )
                charts = lida.visualize(
                    summary=summary, goal=user_query, textgen_config=textgen_config
                )
                chart = _show_first_chart(charts)
                if chart is not None:
                    # Explicit write: Streamlit "magic" does not apply inside
                    # functions and is easy to miss when reading the script.
                    st.write(chart)
            else:
                st.warning("Please enter a query first.")