# franc-v0.9 / app.py
import os
import gradio as gr
import pinecone
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferWindowMemory
from langchain.vectorstores import Pinecone
from torch import cuda
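
# Configuration is read from environment variables (e.g., Space secrets):
# the Llama 2 endpoint URL, the Hugging Face token, and the Pinecone credentials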
LLAMA_2_7B_CHAT_HF_FRANC_V0_9 = os.environ.get("LLAMA_2_7B_CHAT_HF_FRANC_V0_9")
HUGGING_FACE_HUB_TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN")
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT")

# Set up the Pinecone vector store
pinecone.init(
api_key=PINECONE_API_KEY,
environment=PINECONE_ENVIRONMENT
)
index_name = 'stadion-6237'
index = pinecone.Index(index_name)
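
# Embed queries with a sentence-transformers model, on GPU when available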
embedding_model_id = 'sentence-transformers/paraphrase-mpnet-base-v2'
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
embedding_model = HuggingFaceEmbeddings(
model_name=embedding_model_id,
model_kwargs={'device': device},
encode_kwargs={'device': device, 'batch_size': 32}
)
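
# Wrap the index as a LangChain vector store; each vector's source text is
# stored under the 'text' metadata key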
text_key = 'text'
vector_store = Pinecone(
index, embedding_model.embed_query, text_key
)
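
# Llama 2 chat prompt delimiters: [INST] ... [/INST] wraps the instruction,
# and <<SYS>> ... <</SYS>> wraps the system prompt inside it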
B_INST, E_INST = "[INST] ", " [/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

def get_prompt_template(instruction, system_prompt):
    """Wrap an instruction and a system prompt in the Llama 2 chat format."""
    system_prompt = B_SYS + system_prompt + E_SYS
    prompt_template = B_INST + system_prompt + instruction + E_INST
    return prompt_template
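
# Instruction: answer the question from the retrieved context;
# system prompt: keep replies short, with no emotes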
template = get_prompt_template(
"""Use the following context to answer the question at the end.
Context:
{context}
Question: {question}""",
"""Reply in 10 sentences or less.
Do not use emotes."""
)
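
# Text-generation LLM served by a dedicated Hugging Face Inference Endpoint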
endpoint_url = LLAMA_2_7B_CHAT_HF_FRANC_V0_9
llm = HuggingFaceEndpoint(
    endpoint_url=endpoint_url,
    huggingfacehub_api_token=HUGGING_FACE_HUB_TOKEN,
    task="text-generation",
    model_kwargs={
        "max_new_tokens": 512,
        "temperature": 0.1,  # near-greedy decoding for grounded answers
        "repetition_penalty": 1.1,
        "return_full_text": True,
    },
)
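
# Prompt template filled with the retrieved {context} and the user's {question}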
prompt = PromptTemplate(
template=template,
input_variables=["context", "question"]
)
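
# Rolling three-turn conversation memory; defined but not yet wired into the
# chain (see the commented-out "memory" entry below)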
memory = ConversationBufferWindowMemory(
k=3,
memory_key="history",
input_key="question",
ai_prefix="Franc",
human_prefix="Runner",
)
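
# Retrieve the top 4 matches and "stuff" them all into a single prompt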
rag_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type='stuff',
retriever=vector_store.as_retriever(search_kwargs={'k': 4}),
chain_type_kwargs={
"prompt": prompt,
# "memory": memory,
},
)

def generate(message, history):
    """Answer a chat message with the RAG chain; Gradio manages the history."""
    reply = rag_chain(message)
    return reply["result"].strip()
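
# Chat UI; queue() enables request queuing so calls are processed in order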
gr.ChatInterface(
generate,
title="Franc v1.0",
theme=gr.themes.Soft(),
submit_btn="Ask Franc",
retry_btn="Do better, Franc!",
autofocus=True,
).queue().launch()