CCCDev committed on
Commit
2e1abdd
·
verified ·
1 Parent(s): 033eeb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -3,18 +3,19 @@ from langchain_community.document_loaders import PyPDFLoader
3
  from langchain.text_splitter import RecursiveCharacterTextSplitter
4
  from langchain_community.vectorstores import Chroma
5
  from langchain.chains import ConversationalRetrievalChain
6
- from langchain_huggingface import HuggingFaceEmbeddings
7
- from langchain.chains import ConversationChain
8
  from langchain.memory import ConversationBufferMemory
9
 
10
  from pathlib import Path
11
  import chromadb
12
  from unidecode import unidecode
 
13
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
14
  import re
15
 
16
  # Constants
17
- LLM_MODEL = "t5-small" # Changed to a Seq2Seq model compatible with AutoModelForSeq2SeqLM
 
18
  DB_CHUNK_SIZE = 512
19
  CHUNK_OVERLAP = 24
20
  TEMPERATURE = 0.1
@@ -43,13 +44,12 @@ def create_db(splits, collection_name):
43
  return vectordb
44
 
45
  # Initialize langchain LLM chain
46
- def initialize_llmchain(llm_model, vector_db, progress=gr.Progress()):
47
  progress(0.5, desc="Initializing HF Hub...")
48
-
49
- # Create the HuggingFacePipeline for the model
50
  tokenizer = AutoTokenizer.from_pretrained(llm_model)
51
  model = AutoModelForSeq2SeqLM.from_pretrained(llm_model)
52
- pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
53
 
54
  progress(0.75, desc="Defining buffer memory...")
55
  memory = ConversationBufferMemory(
@@ -95,8 +95,8 @@ def initialize_database(pdf_url, chunk_size, chunk_overlap, progress=gr.Progress
95
  progress(0.9, desc="Done!")
96
  return vector_db, collection_name, "Complete!"
97
 
98
- def initialize_LLM(vector_db, progress=gr.Progress()):
99
- qa_chain = initialize_llmchain(LLM_MODEL, vector_db, progress)
100
  return qa_chain, "Complete!"
101
 
102
  def format_chat_history(message, chat_history):
@@ -165,7 +165,7 @@ def demo():
165
 
166
  def auto_initialize():
167
  vector_db, collection_name, db_status = initialize_database(pdf_url, DB_CHUNK_SIZE, CHUNK_OVERLAP)
168
- qa_chain, llm_status = initialize_LLM(vector_db)
169
  return vector_db, collection_name, db_status, qa_chain, llm_status, "Initialization complete."
170
 
171
  demo.load(auto_initialize, [], [vector_db, collection_name, db_progress, qa_chain, llm_progress])
 
3
  from langchain.text_splitter import RecursiveCharacterTextSplitter
4
  from langchain_community.vectorstores import Chroma
5
  from langchain.chains import ConversationalRetrievalChain
6
+ from langchain_community.embeddings import HuggingFaceEmbeddings
 
7
  from langchain.memory import ConversationBufferMemory
8
 
9
  from pathlib import Path
10
  import chromadb
11
  from unidecode import unidecode
12
+
13
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
14
  import re
15
 
16
  # Constants
17
+ LLM_MODEL = "facebook/bart-large-cnn" # Changed to a model with larger response capabilities
18
+ LLM_MAX_TOKEN = 512
19
  DB_CHUNK_SIZE = 512
20
  CHUNK_OVERLAP = 24
21
  TEMPERATURE = 0.1
 
44
  return vectordb
45
 
46
  # Initialize langchain LLM chain
47
+ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
48
  progress(0.5, desc="Initializing HF Hub...")
49
+
 
50
  tokenizer = AutoTokenizer.from_pretrained(llm_model)
51
  model = AutoModelForSeq2SeqLM.from_pretrained(llm_model)
52
+ pipe = pipeline("summarization", model=model, tokenizer=tokenizer)
53
 
54
  progress(0.75, desc="Defining buffer memory...")
55
  memory = ConversationBufferMemory(
 
95
  progress(0.9, desc="Done!")
96
  return vector_db, collection_name, "Complete!"
97
 
98
+ def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
99
+ qa_chain = initialize_llmchain(LLM_MODEL, llm_temperature, max_tokens, top_k, vector_db, progress)
100
  return qa_chain, "Complete!"
101
 
102
  def format_chat_history(message, chat_history):
 
165
 
166
  def auto_initialize():
167
  vector_db, collection_name, db_status = initialize_database(pdf_url, DB_CHUNK_SIZE, CHUNK_OVERLAP)
168
+ qa_chain, llm_status = initialize_LLM(TEMPERATURE, LLM_MAX_TOKEN, 20, vector_db)
169
  return vector_db, collection_name, db_status, qa_chain, llm_status, "Initialization complete."
170
 
171
  demo.load(auto_initialize, [], [vector_db, collection_name, db_progress, qa_chain, llm_progress])