File size: 2,081 Bytes
5f0df75 8ea257e 5f0df75 8ea257e 5f0df75 c72c1d8 aa50b20 5f0df75 c72c1d8 5f0df75 ddcd2fe 59b03a4 c72c1d8 5f0df75 c72c1d8 5f0df75 8ea257e 5f0df75 8ea257e 5f0df75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import Pinecone
from langchain_community.embeddings import HuggingFaceEmbeddings
import os
# Runtime configuration, read once from the process environment at import time.
# OpenAI credential handed to ChatOpenAI; None when OPENAI_API_KEY is unset.
openai_api_key = os.getenv("OPENAI_API_KEY")
# Name of the sentence-embedding model; get_vectordb loads it from "model/<name>".
model_name = os.getenv("MODEL_NAME", "all-mpnet-base-v2")
# Name of the Pinecone index to open; None when PINECONE_INDEX is unset.
pinecone_index = os.getenv("PINECONE_INDEX")
class Conversation_RAG:
    """Factory for the pieces of a retrieval-augmented conversational QA chain.

    Only the chat-model name is stored on the instance. The Pinecone index
    name, the embedding-model name and the OpenAI API key come from the
    module-level ``pinecone_index``, ``model_name`` and ``openai_api_key``
    globals.
    """

    # Section appended after the caller's instruction. ``{context}`` and
    # ``{question}`` are the only template placeholders in the final prompt.
    _PROMPT_SUFFIX = (
        "\n"
        "**Document Context Input**:\n\n"
        "{context}\n\n"
        "**Client Case Input**: {question}\n\n"
    )

    # NOTE(review): OpenAI documents "gpt-3.5-turbo-instruct" as a
    # Completions-API model, while ChatOpenAI talks to the Chat Completions
    # endpoint — confirm this default works, or pass a chat model explicitly.
    def __init__(self, model_name="gpt-3.5-turbo-instruct"):
        """Remember the name of the OpenAI model used by ``create_model``."""
        self.model_name = model_name

    def get_vectordb(self, pc):
        """Return a LangChain ``Pinecone`` vector store over the configured index.

        Args:
            pc: A Pinecone client; only its ``Index(name)`` method is used.

        Returns:
            A ``Pinecone`` store opened on the index named by the module-level
            ``pinecone_index``, embedding with the local HuggingFace model at
            ``model/<model_name>`` and using ``"text"`` as its text key.
        """
        index = pc.Index(pinecone_index)
        # ``model_name`` here is the module-level embedding-model name, NOT
        # ``self.model_name`` (which names the chat model).
        embeddings = HuggingFaceEmbeddings(model_name=f"model/{model_name}")
        return Pinecone(index, embeddings, "text")

    def create_model(self, max_new_tokens=512, temperature=0.8):
        """Return a ``ChatOpenAI`` client for ``self.model_name``.

        Args:
            max_new_tokens: Forwarded to ChatOpenAI as ``max_tokens``.
            temperature: Sampling temperature forwarded to ChatOpenAI.

        Returns:
            The configured ``ChatOpenAI`` instance.
        """
        return ChatOpenAI(
            openai_api_key=openai_api_key,
            model_name=self.model_name,
            temperature=temperature,
            max_tokens=max_new_tokens,
        )

    @classmethod
    def _build_prompt_template(cls, instruction):
        """Return the f-string template: *instruction* followed by the context/question section.

        Braces in *instruction* are doubled so they are rendered literally
        instead of being parsed as placeholders; otherwise an instruction that
        contains e.g. a JSON example would fail when the prompt is formatted.
        """
        escaped = instruction.replace("{", "{{").replace("}", "}}")
        return escaped + cls._PROMPT_SUFFIX

    def create_conversation(
        self,
        model,
        vectordb,
        k_context=5,
        # NOTE(review): "at the end by." has a dangling word; the text is kept
        # unchanged because it is the public default prompt.
        instruction=(
            "Use the following pieces of context to answer the question at the end by. "
            "Generate the answer based on the given context only. "
            "If you do not find any information related to the question in the given "
            "context, just say that you don't know, don't try to make up an answer. "
            "Keep your answer expressive."
        ),
        verbose=True,
    ):
        """Build a ``ConversationalRetrievalChain`` that answers from retrieved context.

        Args:
            model: The LLM driving the chain (e.g. the result of ``create_model``).
            vectordb: A vector store exposing ``as_retriever`` (e.g. the result
                of ``get_vectordb``).
            k_context: Number of documents retrieved per question.
            instruction: Literal instruction text placed before the retrieved
                context; braces in it are rendered verbatim.
            verbose: Forwarded to the chain's ``verbose`` flag.

        Returns:
            The configured ``ConversationalRetrievalChain``.
        """
        qca_prompt = PromptTemplate(
            # The instruction is baked into the template text, so the only
            # variables the template references are context and question.
            input_variables=["context", "question"],
            template=self._build_prompt_template(instruction),
        )
        return ConversationalRetrievalChain.from_llm(
            llm=model,
            chain_type="stuff",
            retriever=vectordb.as_retriever(search_kwargs={"k": k_context}),
            combine_docs_chain_kwargs={"prompt": qca_prompt},
            # History is passed through unchanged; presumably callers supply an
            # already-formatted string — TODO confirm against callers.
            get_chat_history=lambda h: h,
            verbose=verbose,
        )