# File size: 2,564 Bytes
# 4b30b1b
import base64
import io
import os
import tempfile
from io import BytesIO

import streamlit as st
from lida import Manager, TextGenerationConfig, llm
from PIL import Image


def base64_to_image(base64_string):
    """Turn a base64-encoded image payload into a PIL image.

    Args:
        base64_string: Base64 text of an encoded image, e.g. the ``raster``
            field of a chart returned by ``lida.visualize``.

    Returns:
        The result of ``PIL.Image.open`` over an in-memory buffer holding
        the decoded bytes.
    """
    buffer = BytesIO(base64.b64decode(base64_string))
    return Image.open(buffer)



@st.cache_resource
def _load_text_generator():
    """Build the HuggingFace-backed text generator used by LIDA.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` makes the 13B model load only once per server
    process instead of on every rerun.
    """
    # Configuration from the LIDA README. ``provider="hf"`` explicitly selects
    # the local HuggingFace backend; without it llmx falls back to its
    # configured/default provider (NOTE(review): presumably OpenAI) rather than
    # loading this HF model.
    return llm(
        provider="hf",
        model="uukuguy/speechless-llama2-hermes-orca-platypus-13b",
        device_map="auto",
    )


text_gen = _load_text_generator()

# Built on every run (cheap) rather than cached: the Manager presumably keeps
# per-run state such as the last summarized dataset, which should not be
# shared between user sessions — confirm against LIDA's Manager.
lida = Manager(text_gen=text_gen)

# Shared generation settings: one completion, low temperature, capped output.
textgen_config = TextGenerationConfig(n=1, temperature=0.1, max_tokens=512)


def _save_upload(uploaded_file, directory, filename):
    """Write an uploaded file's bytes to ``directory/filename``; return the path."""
    path = os.path.join(directory, filename)
    with open(path, "wb") as f:
        f.write(uploaded_file.getvalue())
    return path


def _render_first_chart(charts):
    """Show the first chart of a ``lida.visualize`` result as an image.

    Returns the chart that was shown, or ``None`` after displaying an error
    when no chart was produced or the chart carries no raster image.
    """
    if not charts:
        st.error("No chart could be generated.")
        return None
    chart = charts[0]
    # NOTE(review): LIDA appears to leave ``raster`` empty when the generated
    # plotting code fails; decoding it would then raise, so report instead.
    if not chart.raster:
        st.error("The chart could not be rendered.")
        return None
    st.image(base64_to_image(chart.raster))
    return chart


menu = st.sidebar.selectbox("Summary or Query", ["Summary", "Query"])

if menu == "Summary":
    st.subheader("Summarization of the Data")
    file_uploader = st.file_uploader("Upload your file", type="csv")
    if file_uploader is not None:
        # A per-run temporary directory keeps concurrent sessions from
        # overwriting each other's upload; the fixed base name keeps the file
        # name LIDA sees the same as before.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path_to_save = _save_upload(file_uploader, tmp_dir, "filename.csv")
            summary = lida.summarize(path_to_save, summary_method="default")
            st.write(summary)
            goals = lida.goals(summary, n=2, textgen_config=textgen_config)
            for goal in goals:
                st.write(goal)
            if goals:
                library = "seaborn"
                charts = lida.visualize(
                    summary=summary,
                    goal=goals[0],
                    textgen_config=textgen_config,
                    library=library,
                )
                _render_first_chart(charts)
            else:
                st.warning("No visualization goals were generated for this data.")

elif menu == "Query":
    st.subheader("Questioning of the Data")
    file_uploader = st.file_uploader("Upload your file", type="csv")
    if file_uploader is not None:
        text_area = st.text_area("Query your data to generate visualization", height=200)
        if st.button("Generate Graph"):
            user_query = text_area.strip()
            # Ignore empty or whitespace-only queries instead of sending them to the LLM.
            if user_query:
                with tempfile.TemporaryDirectory() as tmp_dir:
                    path_to_save = _save_upload(file_uploader, tmp_dir, "filename1.csv")
                    summary = lida.summarize(
                        path_to_save, summary_method="default", textgen_config=textgen_config
                    )
                    charts = lida.visualize(
                        summary=summary, goal=user_query, textgen_config=textgen_config
                    )
                chart = _render_first_chart(charts)
                if chart is not None:
                    # Explicit form of the original bare ``charts[0]`` (Streamlit magic).
                    st.write(chart)