Kathirsci commited on
Commit
1e615dd
·
verified ·
1 Parent(s): 3b8853c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +182 -25
app.py CHANGED
@@ -1,25 +1,182 @@
1
- from transformers import pipeline
2
- from huggingface_hub import snapshot_download
3
- from pathlib import Path
4
-
5
- # Define the messages for the chatbot with the pirate persona
6
- messages = [
7
- {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
8
- {"role": "user", "content": "Who are you?"},
9
- ]
10
-
11
- # Initialize the chatbot pipeline
12
- chatbot = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3")
13
-
14
- # Generate the response from the chatbot
15
- response = chatbot(messages)
16
- print(response)
17
-
18
- # Define the path to download the model files
19
- mistral_models_path = Path.home().joinpath('mistral_models', '7B-Instruct-v0.3')
20
- mistral_models_path.mkdir(parents=True, exist_ok=True)
21
-
22
- # Download the model files
23
- snapshot_download(repo_id="mistralai/Mistral-7B-Instruct-v0.3",
24
- allow_patterns=["params.json", "consolidated.safetensors", "tokenizer.model.v3"],
25
- local_dir=mistral_models_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tempfile
3
+ import os
4
+ import logging
5
+ import subprocess
6
+ from typing import List
7
+ from langchain_community.document_loaders import PyPDFLoader
8
+ from langchain_huggingface import HuggingFaceEmbeddings
9
+ from langchain_community.vectorstores import FAISS
10
+ from langchain.schema import Document
11
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
12
+ from langchain.prompts import PromptTemplate
13
+ from langchain.text_splitter import CharacterTextSplitter
14
+ from langchain.runnables import RunnableMap, RunnableLambda
15
+ from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
16
+
17
+
18
+ # Set up logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Constants
23
+ DB_FAISS_PATH = 'vectorstore/db_faiss'
24
+ EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
25
+ DEFAULT_MODEL = "google/flan-t5-large" # Replace with your preferred Hugging Face model
26
+
27
+ # Default model parameters
28
+ DEFAULT_PARAMS = {
29
+ "temperature": 0.7,
30
+ "top_p": 1.0,
31
+ "num_ctx": 4096,
32
+ "repeat_penalty": 1.1,
33
+ }
34
+
35
+ def get_default_value(param_name: str, default: float) -> float:
36
+ """Safely get a float value from DEFAULT_PARAMS."""
37
+ value = DEFAULT_PARAMS.get(param_name, default)
38
+ return float(value) if not isinstance(value, list) else float(value[0]) if value else default
39
+
40
+ @st.cache_resource
41
+ def load_embeddings():
42
+ """Load and cache the embedding model."""
43
+ try:
44
+ return HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL, model_kwargs={'device': 'cpu'})
45
+ except Exception as e:
46
+ logger.error(f"Failed to load embeddings: {e}")
47
+ st.error("Failed to load the embedding model. Please try again later.")
48
+ return None
49
+
50
+ @st.cache_resource
51
+ def load_llm(model_name: str):
52
+ """Load and cache the Hugging Face model and tokenizer."""
53
+ try:
54
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
55
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
56
+ summarizer = pipeline("summarization", model=model, tokenizer=tokenizer)
57
+ return summarizer
58
+ except Exception as e:
59
+ logger.error(f"Failed to load LLM: {e}")
60
+ st.error(f"Failed to load the model {model_name}. Please check the model name and try again.")
61
+ return None
62
+
63
+ def process_pdf(file) -> List[Document]:
64
+ try:
65
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
66
+ temp_file.write(file.getvalue())
67
+ temp_file_path = temp_file.name
68
+ loader = PyPDFLoader(file_path=temp_file_path)
69
+ documents = loader.load() # This loads each page as a separate Document
70
+ os.unlink(temp_file_path) # Clean up the temporary file
71
+ return documents
72
+ except Exception as e:
73
+ logger.error(f"Error processing PDF: {e}")
74
+ st.error("Failed to process the PDF. Please make sure it's a valid PDF file.")
75
+ return []
76
+
77
+ def create_vector_store(documents: List[Document], embeddings):
78
+ """Create and save the vector store."""
79
+ try:
80
+ db = FAISS.from_documents(documents, embeddings)
81
+ db.save_local(DB_FAISS_PATH)
82
+ return db
83
+ except Exception as e:
84
+ logger.error(f"Error creating vector store: {e}")
85
+ st.error("Failed to create the vector store. Please try again.")
86
+ return None
87
+
88
+ def summarize_report(documents: List[Document], summarizer) -> str:
89
+ """Summarize the report using a map-reduce approach."""
90
+ try:
91
+ # Limit the number of chunks to process
92
+ max_chunks = 50 # Adjust this value based on your needs
93
+ if len(documents) > max_chunks:
94
+ st.warning(f"Document is very large. Summarizing first {max_chunks} chunks only.")
95
+ documents = documents[:max_chunks]
96
+
97
+ # Map prompt
98
+ def map_fn(text):
99
+ summary = summarizer(text, max_length=150, min_length=40, do_sample=False)[0]['summary_text']
100
+ return summary
101
+
102
+ # Reduce prompt
103
+ def reduce_fn(summaries):
104
+ combined_text = " ".join(summaries)
105
+ final_summary = summarizer(combined_text, max_length=300, min_length=100, do_sample=False)[0]['summary_text']
106
+ return final_summary
107
+
108
+ # RunnableSequence replaces the deprecated LLMChain
109
+ map_chain = RunnableMap(
110
+ llm_chain=lambda text: map_fn(text)
111
+ )
112
+
113
+ reduce_chain = RunnableLambda(
114
+ llm_chain=lambda doc_summaries: reduce_fn(doc_summaries)
115
+ )
116
+
117
+ with st.spinner("Generating summary..."):
118
+ # Run map-reduce sequence
119
+ summaries = map_chain.run([doc.page_content for doc in documents])
120
+ summary = reduce_chain.run({"doc_summaries": summaries})
121
+
122
+ return summary
123
+
124
+ except Exception as e:
125
+ logger.error(f"Error summarizing report: {e}")
126
+ st.error("Failed to summarize the report. Please try again.")
127
+ return ""
128
+
129
+ def main():
130
+ st.title("Report Summarizer ")
131
+
132
+ model_option = st.sidebar.text_input("Enter Hugging Face model name", value=DEFAULT_MODEL)
133
+
134
+ # Advanced options
135
+ with st.sidebar.expander("Advanced Model Parameters"):
136
+ custom_temp = st.slider("Temperature", 0.0, 1.0,
137
+ value=get_default_value("temperature", 0.7),
138
+ step=0.01)
139
+ custom_top_p = st.slider("Top P", 0.0, 1.0,
140
+ value=get_default_value("top_p", 1.0),
141
+ step=0.01)
142
+ custom_num_ctx = st.number_input("Context Window", 1024, 8192,
143
+ value=int(get_default_value("num_ctx", 4096)))
144
+ custom_repeat_penalty = st.slider("Repeat Penalty", 1.0, 2.0,
145
+ value=get_default_value("repeat_penalty", 1.1),
146
+ step=0.01)
147
+
148
+ custom_params = {
149
+ "temperature": custom_temp,
150
+ "top_p": custom_top_p,
151
+ "num_ctx": custom_num_ctx,
152
+ "repeat_penalty": custom_repeat_penalty
153
+ }
154
+
155
+ uploaded_file = st.sidebar.file_uploader("Upload your Report", type="pdf")
156
+
157
+ summarizer = load_llm(model_option)
158
+ embeddings = load_embeddings()
159
+
160
+ if not summarizer or not embeddings:
161
+ return
162
+
163
+ if uploaded_file:
164
+ with st.spinner("Processing PDF..."):
165
+ documents = process_pdf(uploaded_file)
166
+
167
+ if documents:
168
+ with st.spinner("Creating vector store..."):
169
+ db = create_vector_store(documents, embeddings)
170
+
171
+ if db and st.button("Summarize"):
172
+ with st.spinner(f"Generating structured summary using {model_option}..."):
173
+ summary = summarize_report(documents, summarizer)
174
+
175
+ if summary:
176
+ st.subheader("Structured Summary:")
177
+ st.markdown(summary)
178
+ else:
179
+ st.warning("Failed to generate summary. Please try again.")
180
+
181
+ if __name__ == "__main__":
182
+ main()