# logging
import logging

# access .env file
import os
from dotenv import load_dotenv

# boto3 for S3 access
import boto3
from botocore import UNSIGNED
from botocore.client import Config

# HF libraries
from langchain_community.llms import HuggingFaceHub
from langchain_community.embeddings import HuggingFaceHubEmbeddings
# vectorstore
from langchain_community.vectorstores import Chroma
from langchain_community.vectorstores import FAISS
import zipfile

# retrieval chain
from langchain.chains import RetrievalQAWithSourcesChain
# prompt template
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory

# github issues
#from langchain.document_loaders import GitHubIssuesLoader
# debugging
from langchain.globals import set_verbose
# caching
from langchain.globals import set_llm_cache
# SQLite-backed cache for LLM calls
from langchain_community.cache import SQLiteCache


# template for prompt
from prompt import template



set_verbose(True)

# set up logging for the chain
logging.basicConfig()
logging.getLogger("langchain.retrievers").setLevel(logging.INFO)
logging.getLogger("langchain.chains.qa_with_sources").setLevel(logging.INFO)

# load .env variables
load_dotenv(".env")  # returns a bool, so there is no config object to keep
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')  # picked up from the environment by HuggingFaceHub
AWS_S3_LOCATION = os.getenv('AWS_S3_LOCATION')
AWS_S3_FILE = os.getenv('AWS_S3_FILE')
VS_DESTINATION = os.getenv('VS_DESTINATION')

# remove old vectorstore
if os.path.exists(VS_DESTINATION):
    os.remove(VS_DESTINATION)

# remove old sqlite cache
if os.path.exists('.langchain.sqlite'):
    os.remove('.langchain.sqlite')



# initialize Model config
llm_model_name = "mistralai/Mistral-7B-Instruct-v0.1"

# renamed from model_id to llm, as is common
llm = HuggingFaceHub(repo_id=llm_model_name, model_kwargs={
    # "temperature": 0.1,
    "max_new_tokens": 1024,
    "repetition_penalty": 1.2,
    # "streaming": True,
    "return_full_text": False,
    })
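# the commented-out kwargs ("temperature", "streaming") are optional knobs;
# max_new_tokens caps the response length, repetition_penalty discourages
# repeated phrases, and return_full_text=False strips the prompt from the
# generated output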

# initialize Embedding config
embedding_model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
embeddings = HuggingFaceHubEmbeddings(repo_id=embedding_model_name)

set_llm_cache(SQLiteCache(database_path=".langchain.sqlite"))
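# from here on, repeated identical LLM calls are served from the
# .langchain.sqlite cache instead of hitting the Hub endpoint again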

# retrieve vectorstore from S3
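# signature_version=UNSIGNED sends anonymous (unsigned) requests, so the
# public bucket can be read without AWS credentials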
s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))

## download vectorstore from S3
s3.download_file(AWS_S3_LOCATION, AWS_S3_FILE, VS_DESTINATION)
with zipfile.ZipFile(VS_DESTINATION, 'r') as zip_ref:
    zip_ref.extractall('./vectorstore/')

FAISS_INDEX_PATH = './vectorstore/lc-faiss-multi-qa-mpnet'
db = FAISS.load_local(FAISS_INDEX_PATH, embeddings)
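# note: newer langchain-community releases also require passing
# allow_dangerous_deserialization=True here, since loading a saved FAISS
# index unpickles its docstore from disk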

# alternative: a Chroma vectorstore with cached embeddings to speed up re-retrieval
# db = Chroma(persist_directory="./vectorstore", embedding_function=embeddings)
# db.get()

retriever = db.as_retriever(search_type="mmr", search_kwargs={'k': 3, 'lambda_mult': 0.25})
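# MMR (maximal marginal relevance) re-ranks candidates for diversity:
# k=3 documents are returned, and lambda_mult=0.25 leans toward diversity
# (1.0 = pure relevance, 0.0 = maximum diversity)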

prompt = PromptTemplate(
    input_variables=["history", "context", "question"],
    template=template,
)
memory = ConversationBufferMemory(memory_key="history", input_key="question")
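# memory_key="history" and input_key="question" must match the variable
# names declared in the PromptTemplate above so the chain can inject the
# chat history and the user question into the template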

qa = RetrievalQAWithSourcesChain.from_chain_type(
    llm=llm,
    retriever=retriever,
    return_source_documents=True,
    verbose=True,
    chain_type_kwargs={
        "verbose": True,
        "memory": memory,
        "prompt": prompt,
        "document_variable_name": "context",
    },
)
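
# example invocation (a minimal sketch; the question string is illustrative,
# not part of the original app):
# result = qa({"question": "What does a retriever do in LangChain?"})
# print(result["answer"])
# print(result["sources"])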