Sahithi-07 committed on
Commit
37599f2
1 Parent(s): d3f5c9c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, BartForConditionalGeneration

# Load the TAPEX tokenizer and model (replace with your fine-tuned model names).
# NOTE(review): these are fetched from the Hugging Face hub at import time, so
# the first run needs network access and the download is large — confirm this
# is acceptable for the deployment target.
tokenizer = AutoTokenizer.from_pretrained("microsoft/tapex-large-finetuned-wtq")
model = BartForConditionalGeneration.from_pretrained("microsoft/tapex-large-finetuned-wtq")
10
def predict(table_path, query):
    """
    Answer a natural-language question about a table using the TAPEX model.

    Args:
        table_path: Path to — or file-like buffer of — a CSV file containing
            the table data. Anything accepted by ``pandas.read_csv`` works.
        query: The question to be answered.

    Returns:
        The predicted answer as a string.
    """
    # Bug fix: the original ignored ``table_path`` entirely and read a
    # hard-coded local Windows path, which fails on any other machine.
    table = pd.read_csv(table_path)
    # The TAPEX tokenizer requires every cell value to be a string.
    table = table.astype(str)

    # Truncate the encoded table+query so it fits the model's maximum
    # input sequence length.
    max_length = model.config.max_position_embeddings
    encoding = tokenizer(
        table=table,
        query=query,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )

    # Generate the output token ids, then decode them to a plain string.
    outputs = model.generate(**encoding)
    prediction = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    return prediction
35
+
36
st.title("Chatbot with CSV using TAPEX")

# Upload table data
uploaded_file = st.file_uploader("Upload Sales Data (CSV)", type="csv")

if uploaded_file is not None:
    # Read the uploaded CSV and display it so the user can see what
    # they are querying against.
    df = pd.read_csv(uploaded_file)
    st.write(df)

    # User query input
    query = st.text_input("Hello ! Ask me anything about " + uploaded_file.name + " 🤗")

    if query:
        # Bug fix: the original passed ``uploaded_file.name`` — a bare
        # client-side filename that does not exist on the server's
        # filesystem. Pass the in-memory upload itself instead, rewinding
        # it first because ``pd.read_csv`` above consumed the buffer.
        uploaded_file.seek(0)
        prediction = predict(uploaded_file, query)
        st.write(f"*Your Question:* {query}")
        st.write(f"*Predicted Answer:* {prediction}")
else:
    st.info("Please upload a CSV file containing sales data.")