import os

import gradio as gr
import pinecone
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferWindowMemory
from langchain.vectorstores import Pinecone
from torch import cuda

# Model endpoint URL and API credentials, read from the environment
LLAMA_2_7B_CHAT_HF_FRANC_V0_9 = os.environ.get("LLAMA_2_7B_CHAT_HF_FRANC_V0_9")
HUGGING_FACE_HUB_TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN")
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT")

# Initialize the Pinecone client and connect to the index
pinecone.init(
    api_key=PINECONE_API_KEY,
    environment=PINECONE_ENVIRONMENT
)

index_name = 'stadion-6237'
index = pinecone.Index(index_name)
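
# Query embeddings should come from the same model that indexed the corpus;
# run on GPU when one is available.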
embedding_model_id = 'sentence-transformers/paraphrase-mpnet-base-v2'
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
embedding_model = HuggingFaceEmbeddings(
    model_name=embedding_model_id,
    model_kwargs={'device': device},
    encode_kwargs={'device': device, 'batch_size': 32}
)
# 'text' is the metadata field holding each chunk's original text
text_key = 'text'

vector_store = Pinecone(
    index, embedding_model.embed_query, text_key
)

# Llama 2 chat formatting: [INST]/[/INST] wrap the user turn,
# <<SYS>>/<</SYS>> wraps the system prompt.
B_INST, E_INST = "[INST] ", " [/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def get_prompt_template(instruction, system_prompt):
    """Wrap an instruction and system prompt in Llama 2 chat markup."""
    system_prompt = B_SYS + system_prompt + E_SYS
    prompt_template = B_INST + system_prompt + instruction + E_INST
    return prompt_template


template = get_prompt_template(
    """Use the following context to answer the question at the end.

Context:

{context}

Question: {question}""",
    """Reply in 10 sentences or less.
Do not use emotes."""
)

endpoint_url = LLAMA_2_7B_CHAT_HF_FRANC_V0_9

llm = HuggingFaceEndpoint(
    endpoint_url=endpoint_url,
    huggingfacehub_api_token=HUGGING_FACE_HUB_TOKEN,
    task="text-generation",
    model_kwargs={
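        # Conservative generation: low temperature, mild repetition penalty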
        "max_new_tokens": 512,
        "temperature": 0.1,
        "repetition_penalty": 1.1,
        "return_full_text": True,
    },
)

prompt = PromptTemplate(
    template=template,
    input_variables=["context", "question"]
)

# Conversation memory, prepared but not yet wired into the chain below
memory = ConversationBufferWindowMemory(
    k=3,
    memory_key="history",
    input_key="question",
    ai_prefix="Franc",
    human_prefix="Runner",
)

rag_chain = RetrievalQA.from_chain_type(
    llm=llm,
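    # 'stuff' concatenates the top-k retrieved chunks into a single prompt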
    chain_type='stuff',
    retriever=vector_store.as_retriever(search_kwargs={'k': 4}),
    chain_type_kwargs={
        "prompt": prompt,
        # "memory": memory,
    },
)


def generate(message, history):
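    """Gradio chat handler: answer via the RAG chain (history is unused)."""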
    reply = rag_chain(message)

    return reply['result'].strip()


# Build and launch the chat UI; queue() enables request queuing.
gr.ChatInterface(
    generate,
    title="Franc v1.0",
    theme=gr.themes.Soft(),
    submit_btn="Ask Franc",
    retry_btn="Do better, Franc!",
    autofocus=True,
).queue().launch()