Rohitface committed on
Commit
07959c7
·
verified ·
1 Parent(s): eb91249

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -140
app.py CHANGED
@@ -1,155 +1,139 @@
1
- # app.py
2
-
3
  import gradio as gr
4
- from transformers import pipeline
5
- from sentence_transformers import SentenceTransformer
6
- import faiss
7
- import numpy as np
8
-
9
- # --- Backend Logic ---
10
-
11
- # Step 1: Load the necessary models
12
- # OPTIMIZED: Switched to 'google/flan-t5-small' for maximum speed on free hardware.
13
- print("Loading models... This may take a moment, especially the first time.")
14
- generator = pipeline("text2text-generation", model="google/flan-t5-small")
15
- embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
16
- print("Models loaded successfully!")
17
-
18
- def chunk_text(text, chunk_size=256, overlap=32):
19
- """Splits text into overlapping chunks."""
20
- words = text.split()
21
- chunks = []
22
- for i in range(0, len(words), chunk_size - overlap):
23
- chunks.append(" ".join(words[i:i + chunk_size]))
24
- return chunks
25
-
26
- def process_chat_request(user_question, chat_history, state_data):
27
  """
28
- The main function that handles the chat logic using the RAG pipeline.
29
  """
30
- index = state_data.get("index")
31
- chunks = state_data.get("chunks")
32
-
33
- if not all([index, chunks]):
34
- raise gr.Error("File index is missing. Please restart by uploading a file.")
35
- if not user_question:
36
- raise gr.Error("Please enter a question.")
37
 
38
  try:
39
- # 1. RETRIEVE: Find the most relevant chunks
40
- question_embedding = embedder.encode([user_question])
41
- _, top_k_indices = index.search(question_embedding, k=3) # Retrieve top 3 chunks
 
 
 
 
 
42
 
43
- context = " ".join([chunks[i] for i in top_k_indices[0]])
 
44
 
45
- # 2. GENERATE: Create a prompt and get an answer
46
- prompt = f"""
47
- Based on the following context, provide a detailed answer to the user's question.
48
 
49
- CONTEXT:
50
- ---
51
- {context}
52
- ---
53
 
54
- QUESTION: {user_question}
 
 
55
 
56
- ANSWER:
57
- """
 
 
 
58
 
59
- result = generator(
60
- prompt,
61
- max_length=512,
62
- num_beams=4,
63
- temperature=0.1
64
- )
65
- bot_response = result[0]['generated_text']
66
 
67
- except Exception as e:
68
- raise gr.Error(f"An error occurred during processing: {e}")
69
-
70
- chat_history.append((user_question, bot_response))
71
- return "", chat_history
72
-
73
- # --- Gradio UI Definition ---
74
-
75
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="teal", secondary_hue="teal"), title="Text File Analyzer") as demo:
76
- app_state = gr.State({})
77
-
78
- with gr.Column(visible=True) as welcome_page:
79
- gr.Markdown(
80
- """
81
- <div style='text-align: center; font-family: "Garamond", serif; padding-top: 30px;'>
82
- <h1 style='font-size: 3.5em;'>Efficient Text File Analyzer</h1>
83
- <p style='font-size: 1.5em; color: #555;'>Chat with any .txt document using an efficient RAG pipeline.</p>
84
- </div>
85
- """
86
- )
87
- gr.HTML(
88
- """
89
- <div style='text-align: center; padding: 20px;'>
90
- <img src='https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExd2Vjb3M2eGZzN2FkNWZpZzZ0bWl0c2JqZzZlMHVwZ2l4b2t0eXFpcyZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/YWjDA4k2n6d5Ew42zC/giphy.gif'
91
- style='max-width: 350px; margin: auto; border-radius: 20px; box-shadow: 0 8px 16px rgba(0,0,0,0.1);' />
92
- </div>
93
- """
94
- )
95
- # FIXED: Removed the unsupported 'horizontal_alignment' argument.
96
- with gr.Column():
97
- gr.Markdown("<h3 style='text-align: center;'>Upload Your Text File</h3>")
98
- chat_file_upload = gr.File(label="Upload any .txt file", file_types=[".txt"])
99
- lets_chat_button = gr.Button("💬 Process File and Start Chatting 💬", variant="primary")
100
-
101
- with gr.Column(visible=False) as chat_page:
102
- gr.Markdown("<h1 style='text-align: center;'>Chat with your Document</h1>")
103
- chatbot_ui = gr.Chatbot(height=600, bubble_full_width=False)
104
- with gr.Row():
105
- user_input_box = gr.Textbox(placeholder="Ask a question about your file...", scale=5)
106
- submit_button = gr.Button("Send", variant="primary", scale=1)
107
-
108
- def go_to_chat(current_state, chat_file, progress=gr.Progress()):
109
- if chat_file is None:
110
- raise gr.Error("A file must be uploaded.")
111
-
112
- progress(0, desc="Reading file...")
113
- with open(chat_file.name, 'r', encoding='utf-8') as f:
114
- content = f.read()
115
-
116
- progress(0.2, desc="Chunking text...")
117
- chunks = chunk_text(content)
118
-
119
- progress(0.5, desc="Creating embeddings... (This might take a moment)")
120
- embeddings = embedder.encode(chunks, show_progress_bar=True)
121
-
122
- progress(0.8, desc="Building search index...")
123
- index = faiss.IndexFlatL2(embeddings.shape[1])
124
- index.add(np.array(embeddings).astype('float32'))
125
-
126
- new_state = {
127
- "index": index,
128
- "chunks": chunks
129
- }
130
-
131
- progress(1, desc="Done!")
132
- return (
133
- new_state,
134
- gr.Column(visible=False),
135
- gr.Column(visible=True)
136
- )
137
-
138
- lets_chat_button.click(
139
- fn=go_to_chat,
140
- inputs=[app_state, chat_file_upload],
141
- outputs=[app_state, welcome_page, chat_page]
142
- )
143
- submit_button.click(
144
- fn=process_chat_request,
145
- inputs=[user_input_box, chatbot_ui, app_state],
146
- outputs=[user_input_box, chatbot_ui]
147
- )
148
- user_input_box.submit(
149
- fn=process_chat_request,
150
- inputs=[user_input_box, chatbot_ui, app_state],
151
- outputs=[user_input_box, chatbot_ui]
152
  )
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  if __name__ == "__main__":
155
- demo.launch()
 
 
 
1
  import gradio as gr
2
+ import re
3
+ from sentence_transformers import SentenceTransformer, util
4
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
5
+ import torch
6
+
7
+ # --- Model Loading ---
8
+ # Load the sentence transformer model for creating embeddings
9
+ embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
10
+
11
+ # Load the T5 model and tokenizer for question answering
12
+ qa_model_name = 'google/flan-t5-base'
13
+ qa_tokenizer = T5Tokenizer.from_pretrained(qa_model_name)
14
+ qa_model = T5ForConditionalGeneration.from_pretrained(qa_model_name)
15
+
16
+ # --- Global Variables ---
17
+ chat_history_embeddings = None
18
+ chat_lines = []
19
+
20
+ # --- Helper Functions ---
21
+ def process_chat_file(file):
 
 
 
22
  """
23
+ Reads and parses the uploaded WhatsApp chat file.
24
  """
25
+ global chat_history_embeddings, chat_lines
26
+ if file is None:
27
+ return "Please upload a file first.", []
 
 
 
 
28
 
29
  try:
30
+ # Read the file content
31
+ with open(file.name, 'r', encoding='utf-8') as f:
32
+ content = f.read()
33
+
34
+ # Simple line-based parsing (can be improved with regex for more complex formats)
35
+ # This regex is a basic attempt and might need to be adjusted for different WhatsApp export formats.
36
+ # It tries to capture lines that start with a date and time.
37
+ lines = re.split(r'\n(?=\[\d{1,2}/\d{1,2}/\d{2,4}, \d{1,2}:\d{1,2}:\d{1,2}\])', content)
38
 
39
+ # Filter out empty lines and system messages
40
+ chat_lines = [line.strip() for line in lines if line.strip() and ":" in line]
41
 
42
+ if not chat_lines:
43
+ return "Could not find any chat messages in the file. Please check the file format.", []
 
44
 
45
+ # Create embeddings for the chat history
46
+ chat_history_embeddings = embedding_model.encode(chat_lines, convert_to_tensor=True)
 
 
47
 
48
+ return "File processed successfully! You can now ask questions.", []
49
+ except Exception as e:
50
+ return f"An error occurred: {e}", []
51
 
52
+ def get_bot_response(user_message, history, temperature):
53
+ """
54
+ Generates a response from the chatbot.
55
+ """
56
+ global chat_history_embeddings, chat_lines
57
 
58
+ if chat_history_embeddings is None:
59
+ return "Please upload and process a chat file first."
 
 
 
 
 
60
 
61
+ # 1. Find relevant context from the chat history
62
+ question_embedding = embedding_model.encode(user_message, convert_to_tensor=True)
63
+ cos_scores = util.pytorch_cos_sim(question_embedding, chat_history_embeddings)[0]
64
+
65
+ # Get the top 5 most similar chat lines
66
+ top_k = min(5, len(chat_lines))
67
+ top_results = torch.topk(cos_scores, k=top_k)
68
+
69
+ context = ""
70
+ for score, idx in zip(top_results[0], top_results[1]):
71
+ context += chat_lines[idx] + "\n"
72
+
73
+ # 2. Generate an answer using the T5 model
74
+ prompt = f"""
75
+ Answer the following question based on the provided chat history.
76
+ If the answer is not in the context, say "I couldn't find an answer to that in the chat history."
77
+
78
+ Chat History:
79
+ {context}
80
+
81
+ Question: {user_message}
82
+
83
+ Answer:
84
+ """
85
+
86
+ input_ids = qa_tokenizer.encode(prompt, return_tensors='pt')
87
+
88
+ # Generate the output
89
+ output_ids = qa_model.generate(
90
+ input_ids,
91
+ max_length=150,
92
+ num_beams=4,
93
+ temperature=temperature,
94
+ early_stopping=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  )
96
 
97
+ answer = qa_tokenizer.decode(output_ids[0], skip_special_tokens=True)
98
+
99
+ return answer
100
+
101
+ # --- Gradio UI ---
102
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="teal", secondary_hue="orange")) as demo:
103
+ gr.Markdown("# 💬 Chat with your WhatsApp History")
104
+ gr.Markdown("Upload your WhatsApp chat `.txt` file and ask questions about it!")
105
+
106
+ # Fun GIF
107
+ gr.HTML("""
108
+ <div style="text-align: center;">
109
+ <img src="https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExaDB2d2k5eXNoc2FqZzNqZzZqenp2cDIzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZ
zZzZzZzZzZzZzZzZzZzZ-/media/k-pop/images/bts-oppas-and-hyungs-and-dongsaengs-and-no.gif" alt="Chatbot GIF" style="width:300px; height:auto; border-radius: 15px;">
110
+ </div>
111
+ """)
112
+
113
+ with gr.Row():
114
+ with gr.Column(scale=1):
115
+ file_upload = gr.File(label="Upload WhatsApp Chat (.txt)")
116
+ process_button = gr.Button("Process File")
117
+ upload_status = gr.Textbox(label="Status", interactive=False)
118
+
119
+ temperature_slider = gr.Slider(
120
+ minimum=0.1,
121
+ maximum=1.0,
122
+ value=0.1,
123
+ step=0.1,
124
+ label="Temperature",
125
+ info="Lower values are more accurate, higher values are more creative."
126
+ )
127
+
128
+ with gr.Column(scale=2):
129
+ chatbot = gr.Chatbot(label="Chat")
130
+ msg = gr.Textbox(label="Your Question")
131
+ clear = gr.ClearButton([msg, chatbot])
132
+
133
+ # --- Event Handlers ---
134
+ file_upload.upload(process_chat_file, inputs=[file_upload], outputs=[upload_status, chatbot])
135
+ process_button.click(process_chat_file, inputs=[file_upload], outputs=[upload_status, chatbot])
136
+ msg.submit(get_bot_response, [msg, chatbot, temperature_slider], [msg, chatbot])
137
+
138
  if __name__ == "__main__":
139
+ demo.launch(debug=True)