from copy import deepcopy
from langchain.callbacks import StreamlitCallbackHandler
import streamlit as st
import pandas as pd
from io import StringIO
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering
import numpy as np
import weaviate
from weaviate.embedded import EmbeddedOptions
from weaviate import Client
from weaviate.util import generate_uuid5
import logging
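
# Streamlit app: upload a CSV, mirror it into a Weaviate class, and answer
# questions about the table with TAPAS. To launch it (assuming this file is
# saved as app.py):
#   streamlit run app.py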

# Initialize session state attributes. Debug logging stays off unless
# st.session_state.debug is set to True elsewhere (e.g. a sidebar toggle).
if "debug" not in st.session_state:
    st.session_state.debug = False

# LangChain Streamlit callback handler (currently unused by the app)
st_callback = StreamlitCallbackHandler(st.container())

# Custom logging handler that mirrors log records into the Streamlit page.
# Renamed from StreamlitCallbackHandler so it no longer shadows the
# LangChain class imported above.
class StreamlitLogHandler(logging.Handler):
    def emit(self, record):
        log_entry = self.format(record)
        st.write(log_entry)
        
# Initialize TAPAS model and tokenizer (used by ask_llm_chunk below, so
# they must be loaded rather than left commented out)
tokenizer = AutoTokenizer.from_pretrained("google/tapas-large-finetuned-wtq")
model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq")

# Initialize Weaviate client for the embedded instance (used by the schema
# and query helpers below)
client = weaviate.Client(
    embedded_options=EmbeddedOptions()
)

# Global list to store debugging information (displayed by the
# "Show Debugging Information" checkbox at the bottom of the page)
DEBUG_LOGS = []

def log_debug_info(message):
    if st.session_state.debug:
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.DEBUG)

        # Check if StreamlitLogHandler is already added to avoid duplicate logs
        if not any(isinstance(handler, StreamlitLogHandler) for handler in logger.handlers):
            handler = StreamlitLogHandler()
            logger.addHandler(handler)

        logger.debug(message)
        DEBUG_LOGS.append(message)  # keep a copy for the debug checkbox


# Function to check if a class already exists in Weaviate
def class_exists(class_name):
    try:
        client.schema.get(class_name)
        return True
    except Exception:
        return False

def map_dtype_to_weaviate(dtype):
    """
    Map pandas data types to Weaviate data types.
    """
    if "int" in str(dtype):
        return "int"
    elif "float" in str(dtype):
        return "number"
    elif "bool" in str(dtype):
        return "boolean"
    else:
        return "string"

def ingest_data_to_weaviate(dataframe, class_name, class_description):
    # Create class schema
    class_schema = {
        "class": class_name,
        "description": class_description,
        "properties": []  # Start with an empty properties list
    }

    # Try to create the class without properties first. An already-existing
    # class typically surfaces as UnexpectedStatusCodeException, so catch
    # that alongside schema validation errors.
    try:
        client.schema.create({"classes": [class_schema]})
    except (weaviate.exceptions.SchemaValidationException,
            weaviate.exceptions.UnexpectedStatusCodeException):
        # Class might already exist, so we can continue
        pass

    # Now, let's add properties to the class
    for column_name, data_type in zip(dataframe.columns, dataframe.dtypes):
        property_schema = {
            "name": column_name,
            "description": f"Property for {column_name}",
            "dataType": [map_dtype_to_weaviate(data_type)]
        }
        try:
            client.schema.property.create(class_name, property_schema)
        except (weaviate.exceptions.SchemaValidationException,
                weaviate.exceptions.UnexpectedStatusCodeException):
            # Property might already exist, so we can continue
            pass

    # Ingest data: data_object.create takes the properties dict and the
    # class name; the object id must be a valid UUID, so derive one from
    # the class name and row index rather than passing the raw index.
    for index, row in dataframe.iterrows():
        client.data_object.create(
            row.to_dict(),
            class_name,
            uuid=generate_uuid5(f"{class_name}_{index}")
        )

    # Log data ingestion
    log_debug_info(f"Data ingested into Weaviate for class: {class_name}")
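
# A faster bulk-loading sketch (not called by the app; assumes the v3
# Python client's batch API): the batch context manager queues objects and
# flushes them in groups, avoiding one HTTP round trip per row.
def ingest_data_to_weaviate_batched(dataframe, class_name):
    with client.batch as batch:
        for index, row in dataframe.iterrows():
            batch.add_data_object(
                row.to_dict(),
                class_name,
                uuid=generate_uuid5(f"{class_name}_{index}")
            )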

def query_weaviate(class_name, question, properties):
    # Basic semantic search; adapt the query based on the question. In the
    # v3 client, with_near_text expects a dict of concepts, and the class
    # name and returned properties are passed explicitly.
    results = (
        client.query
        .get(class_name, properties)
        .with_near_text({"concepts": [question]})
        .do()
    )
    return results
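
# Example call (hypothetical class and property names):
#   query_weaviate("SalesData", "total revenue by region", ["region", "revenue"])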

def ask_llm_chunk(chunk, questions):
    chunk = chunk.astype(str)
    try:
        inputs = tokenizer(table=chunk, queries=questions, padding="max_length", truncation=True, return_tensors="pt")
    except Exception as e:
        log_debug_info(f"Tokenization error: {e}")
        st.write(f"An error occurred: {e}")
        return ["Error occurred while tokenizing"] * len(questions)

    # Defensive check: with truncation enabled the tokenizer should already
    # cap sequences at TAPAS's 512-token limit, but guard anyway.
    if inputs["input_ids"].shape[1] > 512:
        log_debug_info("Token limit exceeded for chunk")
        st.warning("Token limit exceeded for chunk")
        return ["Token limit exceeded for chunk"] * len(questions)

    # Forward pass, then map token-level logits back to table-cell
    # coordinates. convert_logits_to_predictions also returns aggregation
    # indices (NONE/SUM/COUNT/AVERAGE), which are unused below, so
    # aggregate-style questions may answer poorly.
    outputs = model(**inputs)
    predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
        inputs,
        outputs.logits.detach(),
        outputs.logits_aggregation.detach()
    )

    answers = []
    for coordinates in predicted_answer_coordinates:
        if len(coordinates) == 1:
            row, col = coordinates[0]
            try:
                value = chunk.iloc[row, col]
                log_debug_info(f"Accessed value for row {row}, col {col}: {value}")
                answers.append(value)
            except Exception as e:
                log_debug_info(f"Error accessing value for row {row}, col {col}: {e}")
                st.write(f"An error occurred: {e}")
        else:
            cell_values = []
            for coordinate in coordinates:
                row, col = coordinate
                try:
                    value = chunk.iloc[row, col]
                    cell_values.append(value)
                except Exception as e:
                    log_debug_info(f"Error accessing value for row {row}, col {col}: {e}")
                    st.write(f"An error occurred: {e}")
            answers.append(", ".join(map(str, cell_values)))

    return answers

MAX_ROWS_PER_CHUNK = 200
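# 200 rows per chunk is a heuristic: TAPAS accepts at most 512 tokens, so
# wide or text-heavy tables can still overflow, which the guard in
# ask_llm_chunk handles.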

def summarize_map_reduce(data, questions):
    dataframe = pd.read_csv(StringIO(data))
    num_chunks = len(dataframe) // MAX_ROWS_PER_CHUNK + 1
    dataframe_chunks = [deepcopy(chunk) for chunk in np.array_split(dataframe, num_chunks)]
    # "Map" step: ask every question against every chunk; "reduce" step:
    # join the per-chunk answers per question, so the returned list stays
    # aligned one-to-one with the questions (extending a flat list would
    # break that pairing whenever there is more than one chunk).
    per_question_answers = [[] for _ in questions]
    for chunk in dataframe_chunks:
        chunk_answers = ask_llm_chunk(chunk, questions)
        for i, answer in enumerate(chunk_answers):
            per_question_answers[i].append(str(answer))
    return ["; ".join(answers) for answers in per_question_answers]

def get_class_schema(class_name):
    """
    Get the schema for a specific class.
    """
    all_classes = client.schema.get()["classes"]
    for cls in all_classes:
        if cls["class"] == class_name:
            return cls
    return None
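
# For reference, the v3 schema API returns class definitions shaped like
# (property names and values depend on the class):
#   {"class": "MyClass", "description": "...",
#    "properties": [{"name": "col1", "dataType": ["string"], ...}, ...]}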

st.title("TAPAS Table Question Answering with Weaviate")

# Get existing classes from Weaviate
existing_classes = [cls["class"] for cls in client.schema.get()["classes"]]
class_options = existing_classes + ["New Class"]
selected_class = st.selectbox("Select a class or create a new one:", class_options)

if selected_class == "New Class":
    class_name = st.text_input("Enter the new class name:")
    class_description = st.text_input("Enter a description for the class:")
else:
    class_name = selected_class
    class_description = ""  # We can fetch the description from Weaviate if needed

# Upload CSV data
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])

# Display the schema if an existing class is selected
class_schema = None  # Initialize class_schema to None
if selected_class != "New Class":
    st.write(f"Schema for {selected_class}:")
    class_schema = get_class_schema(selected_class)
    if class_schema:
        properties = class_schema["properties"]
        schema_df = pd.DataFrame(properties)
        st.table(schema_df[["name", "dataType"]])  # Display only the name and dataType columns

# Before ingesting data into Weaviate, check if CSV columns match the class schema
if csv_file is not None:
    data = csv_file.read().decode("utf-8")
    dataframe = pd.read_csv(StringIO(data))

    # Log CSV upload information
    log_debug_info(f"CSV uploaded with shape: {dataframe.shape}")
    
    # Display the uploaded CSV data
    st.write("Uploaded CSV Data:")
    st.write(dataframe)

    # Check if columns match
    if class_schema:  # Ensure class_schema is not None
        schema_columns = [prop["name"] for prop in class_schema["properties"]]
        if set(dataframe.columns) != set(schema_columns):
            st.error("The columns in the uploaded CSV do not match the schema of the selected class. Please check and upload the correct CSV or create a new class.")
        else:
            # Ingest data into Weaviate
            ingest_data_to_weaviate(dataframe, class_name, class_description)

    # Input for questions
    questions = st.text_area("Enter your questions (one per line)")
    questions = questions.split("\n")  # split questions by line
    questions = [q for q in questions if q]  # remove empty strings

    if st.button("Submit"):
        if data and questions:
            answers = summarize_map_reduce(data, questions)
            st.write("Answers:")
            for q, a in zip(questions, answers):
                st.write(f"Question: {q}")
                st.write(f"Answer: {a}")

# Display debugging information
if st.checkbox("Show Debugging Information"):
    st.write("Debugging Logs:")
    for log in DEBUG_LOGS:
        st.write(log)
        
# Add Ctrl+Enter functionality for submitting the questions.
# Caveat: HTML injected through st.markdown is inserted via innerHTML, so
# browsers do not execute this <script> tag; treat the shortcut as
# best-effort (st.components.v1.html runs scripts, but inside an iframe).
st.markdown("""
    <script>
    document.addEventListener("DOMContentLoaded", function(event) {
        document.addEventListener("keydown", function(event) {
            if (event.ctrlKey && event.key === "Enter") {
                document.querySelector(".stButton button").click();
            }
        });
    });
    </script>
    """, unsafe_allow_html=True)