# (removed: non-code file-viewer artifacts — commit-hash gutter and line-number strip — that are not part of the program)
import logging
import math
from copy import deepcopy
from io import StringIO

import numpy as np
import pandas as pd
import streamlit as st
import weaviate
from langchain.callbacks import StreamlitCallbackHandler
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering
from weaviate import Client
from weaviate.embedded import EmbeddedOptions
from weaviate.util import generate_uuid5
# Initialize session state attributes.
# "debug" gates the verbose logging emitted by log_debug_info() below.
if "debug" not in st.session_state:
    st.session_state.debug = False
# NOTE(review): this binds langchain's StreamlitCallbackHandler, but the name
# is immediately shadowed by the logging.Handler subclass defined just below,
# and st_callback is never used afterwards — confirm whether this line (and
# the langchain import) can be removed.
st_callback = StreamlitCallbackHandler(st.container())
class StreamlitCallbackHandler(logging.Handler):
    """Logging handler that mirrors formatted log records into the Streamlit page.

    NOTE: this deliberately shadows langchain's handler of the same name
    imported at the top of the file; log_debug_info() below relies on this
    class, not langchain's.
    """

    def emit(self, record):
        # Render each record directly in the app output.
        st.write(self.format(record))
# Initialize TAPAS model and tokenizer
#tokenizer = AutoTokenizer.from_pretrained("google/tapas-large-finetuned-wtq")
#model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq")
# Initialize Weaviate client for the embedded instance
#client = weaviate.Client(
# embedded_options=EmbeddedOptions()
#)
# Global list to store debugging information; rendered by the
# "Show Debugging Information" checkbox at the bottom of the page.
DEBUG_LOGS = []


def log_debug_info(message):
    """Record *message* for the in-app debug panel and emit it via logging.

    No-op unless st.session_state.debug is True. The message is appended to
    DEBUG_LOGS (bug fix: the list was never populated before, so the debug
    panel at the bottom of the page always rendered empty) and is also routed
    through a module logger with a StreamlitCallbackHandler attached at most
    once.
    """
    if st.session_state.debug:
        # Store for the "Show Debugging Information" panel.
        DEBUG_LOGS.append(message)
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.DEBUG)
        # Avoid attaching a duplicate handler on Streamlit re-runs.
        if not any(isinstance(h, StreamlitCallbackHandler) for h in logger.handlers):
            logger.addHandler(StreamlitCallbackHandler())
        logger.debug(message)
# Function to check if a class already exists in Weaviate
#def class_exists(class_name):
# try:
# client.schema.get_class(class_name)
# return True
# except:
# return False
#def map_dtype_to_weaviate(dtype):
## """
# Map pandas data types to Weaviate data types.
# """
# if "int" in str(dtype):
# return "int"
# elif "float" in str(dtype):
# return "number"
# elif "bool" in str(dtype):
# return "boolean"
# else:
# return "string"
# def ingest_data_to_weaviate(dataframe, class_name, class_description):
# # Create class schema
# class_schema = {
# "class": class_name,
# "description": class_description,
# "properties": [] # Start with an empty properties list
# }
#
# # Try to create the class without properties first
# try:
# client.schema.create({"classes": [class_schema]})
# except weaviate.exceptions.SchemaValidationException:
# # Class might already exist, so we can continue
#         pass
# # Now, let's add properties to the class
# for column_name, data_type in zip(dataframe.columns, dataframe.dtypes):
# property_schema = {
# "name": column_name,
# "description": f"Property for {column_name}",
# "dataType": [map_dtype_to_weaviate(data_type)]
# }
# try:
# client.schema.property.create(class_name, property_schema)
# except weaviate.exceptions.SchemaValidationException:
# # Property might already exist, so we can continue
# pass
#
# # Ingest data
# for index, row in dataframe.iterrows():
# obj = {
# "class": class_name,
# "id": str(index),
# "properties": row.to_dict()
# }
# client.data_object.create(obj)
# Log data ingestion
# log_debug_info(f"Data ingested into Weaviate for class: {class_name}")
def query_weaviate(question, target_class=None):
    """Run a near-text Weaviate query for *question* against a class.

    target_class defaults to the module-level ``class_name`` selected in the
    UI (the previous, implicit behaviour); pass it explicitly to avoid the
    hidden global dependency.

    NOTE(review): ``client`` is currently commented out near the top of the
    file, so this call raises NameError until the Weaviate client is
    re-enabled.
    """
    if target_class is None:
        # Backward-compatible fallback to the UI-selected global.
        target_class = class_name
    results = client.query.get(target_class).with_near_text(question).do()
    return results
def ask_llm_chunk(chunk, questions):
    """Answer *questions* against a single DataFrame *chunk* with TAPAS.

    Returns exactly len(questions) answer strings; multi-cell answers are
    joined with ", ". On tokenization failure or token overflow a placeholder
    string is returned for every question.

    Bug fix: when reading a predicted cell raised in the single-coordinate
    branch, nothing was appended to ``answers``, so later answers shifted out
    of alignment with their questions (callers zip questions with answers).
    A placeholder is now appended instead.
    """
    # TAPAS expects string-typed cells.
    chunk = chunk.astype(str)
    try:
        inputs = tokenizer(
            table=chunk,
            queries=questions,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
    except Exception as e:
        log_debug_info(f"Tokenization error: {e}")
        st.write(f"An error occurred: {e}")
        return ["Error occurred while tokenizing"] * len(questions)

    # TAPAS has a hard 512-token input limit.
    if inputs["input_ids"].shape[1] > 512:
        log_debug_info("Token limit exceeded for chunk")
        st.warning("Token limit exceeded for chunk")
        return ["Token limit exceeded for chunk"] * len(questions)

    outputs = model(**inputs)
    predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
        inputs,
        outputs.logits.detach(),
        outputs.logits_aggregation.detach(),
    )

    answers = []
    for coordinates in predicted_answer_coordinates:
        if len(coordinates) == 1:
            # Single-cell answer.
            row, col = coordinates[0]
            try:
                value = chunk.iloc[row, col]
                log_debug_info(f"Accessed value for row {row}, col {col}: {value}")
                answers.append(value)
            except Exception as e:
                log_debug_info(f"Error accessing value for row {row}, col {col}: {e}")
                st.write(f"An error occurred: {e}")
                # Keep answers aligned with questions even on failure.
                answers.append("Error accessing cell value")
        else:
            # Multi-cell answer: collect every readable cell, join with ", ".
            cell_values = []
            for row, col in coordinates:
                try:
                    cell_values.append(chunk.iloc[row, col])
                except Exception as e:
                    log_debug_info(f"Error accessing value for row {row}, col {col}: {e}")
                    st.write(f"An error occurred: {e}")
            answers.append(", ".join(map(str, cell_values)))
    return answers
# Maximum table rows fed to TAPAS per call (mitigates the 512-token limit).
MAX_ROWS_PER_CHUNK = 200


def summarize_map_reduce(data, questions):
    """Split the CSV text *data* into row chunks and run TAPAS on each.

    Returns the concatenated per-chunk answers, in chunk order.

    Bug fix: the old ``len // N + 1`` chunk count produced one extra *empty*
    chunk whenever the row count was an exact multiple of MAX_ROWS_PER_CHUNK;
    ceiling division avoids feeding an empty table to the tokenizer.
    """
    dataframe = pd.read_csv(StringIO(data))
    # Ceiling division; max(1, ...) keeps the old behaviour for empty input.
    num_chunks = max(1, math.ceil(len(dataframe) / MAX_ROWS_PER_CHUNK))
    dataframe_chunks = [deepcopy(part) for part in np.array_split(dataframe, num_chunks)]
    all_answers = []
    for chunk in dataframe_chunks:
        all_answers.extend(ask_llm_chunk(chunk, questions))
    return all_answers
def get_class_schema(class_name):
    """Return the Weaviate schema dict for *class_name*, or None if absent."""
    all_classes = client.schema.get()["classes"]
    # First class whose "class" field matches, else None.
    return next((cls for cls in all_classes if cls["class"] == class_name), None)
st.title("TAPAS Table Question Answering with Weaviate")

# Get existing classes from Weaviate.
# NOTE(review): `client` is commented out near the top of the file, so this
# lookup raises NameError until the Weaviate client is re-enabled.
existing_classes = [cls["class"] for cls in client.schema.get()["classes"]]
class_options = existing_classes + ["New Class"]
selected_class = st.selectbox("Select a class or create a new one:", class_options)

# Either collect a new class name/description, or reuse the selected one.
if selected_class == "New Class":
    class_name = st.text_input("Enter the new class name:")
    class_description = st.text_input("Enter a description for the class:")
else:
    class_name = selected_class
    class_description = ""  # We can fetch the description from Weaviate if needed
# Upload CSV data
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])

# Display the schema if an existing class is selected.
class_schema = None  # Initialize class_schema to None
if selected_class != "New Class":
    st.write(f"Schema for {selected_class}:")
    class_schema = get_class_schema(selected_class)
    if class_schema:
        properties = class_schema["properties"]
        schema_df = pd.DataFrame(properties)
        st.table(schema_df[["name", "dataType"]])  # Display only the name and dataType columns
# Before ingesting data into Weaviate, check if CSV columns match the class schema.
if csv_file is not None:
    data = csv_file.read().decode("utf-8")
    dataframe = pd.read_csv(StringIO(data))
    # Log CSV upload information
    log_debug_info(f"CSV uploaded with shape: {dataframe.shape}")
    # Display the uploaded CSV data
    st.write("Uploaded CSV Data:")
    st.write(dataframe)
    # Check if columns match (set comparison: order-insensitive).
    if class_schema:  # Ensure class_schema is not None
        schema_columns = [prop["name"] for prop in class_schema["properties"]]
        if set(dataframe.columns) != set(schema_columns):
            st.error("The columns in the uploaded CSV do not match the schema of the selected class. Please check and upload the correct CSV or create a new class.")
        else:
            # Ingest data into Weaviate.
            # NOTE(review): ingest_data_to_weaviate is commented out above, so
            # this call raises NameError — re-enable the helper or guard this.
            ingest_data_to_weaviate(dataframe, class_name, class_description)
# Input for questions: one per line, empty lines dropped.
questions = st.text_area("Enter your questions (one per line)")
questions = [q for q in questions.split("\n") if q]

if st.button("Submit"):
    # Bug fix: `data` is only bound inside the `if csv_file is not None:`
    # branch above, so clicking Submit with no upload raised NameError.
    # Guard on csv_file first; `data` is guaranteed bound when it is not None.
    if csv_file is not None and data and questions:
        answers = summarize_map_reduce(data, questions)
        st.write("Answers:")
        for q, a in zip(questions, answers):
            st.write(f"Question: {q}")
            st.write(f"Answer: {a}")

# Display debugging information collected by log_debug_info().
if st.checkbox("Show Debugging Information"):
    st.write("Debugging Logs:")
    for log in DEBUG_LOGS:
        st.write(log)
# Add Ctrl+Enter functionality for submitting the questions.
# NOTE(review): recent Streamlit versions sanitize/isolate <script> injected
# via st.markdown, so this shortcut may never fire — verify, or consider
# streamlit.components.v1.html instead.
st.markdown("""
<script>
document.addEventListener("DOMContentLoaded", function(event) {
    document.addEventListener("keydown", function(event) {
        if (event.ctrlKey && event.key === "Enter") {
            document.querySelector(".stButton button").click();
        }
    });
});
</script>
""", unsafe_allow_html=True)