RAG chatbot using llama3

Community Article · Published April 22, 2024

[Diagram: the RAG pipeline. The dataset is embedded with mixedbread-ai/mxbai-embed-large-v1 (using datasets and faiss); a query is embedded with the same model, the closest documents are retrieved and formatted into a prompt, and meta-llama/Meta-Llama-3-8B-Instruct generates the output.]

Definition

First, let's define RAG: Retrieval-Augmented Generation. It's a technique used in natural language processing (NLP) to improve the performance of language models by incorporating external knowledge sources, such as databases or search engines. The basic idea is to retrieve relevant information from an external source based on the input query, then pass that information to the model along with the query so it can generate a better-grounded answer.
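
To make the flow concrete, here is a minimal sketch of the idea; retrieve and generate are hypothetical placeholders for the components we will build step by step below:

def rag_answer(query: str) -> str:
    # retrieve: search an external knowledge source for relevant documents (hypothetical helper)
    documents = retrieve(query)
    # augment: put the retrieved context into the prompt alongside the question
    prompt = f"Question: {query}\nContext: {documents}"
    # generate: let the language model answer using that context (hypothetical helper)
    return generate(prompt)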

Tools

For this blog, we need the following libraries:

pip install -q datasets sentence-transformers faiss-cpu accelerate

Embed the original dataset

This is a really slow process, so we advise you to select a GPU.

This is a necessary step and by far the slowest one on our list, so we recommend you embed your dataset once and save it or push it to the Hub to avoid redoing it every time.

Let's start by loading our original dataset:

from datasets import load_dataset

dataset = load_dataset("not-lain/wikipedia")
dataset # Let's checkout our dataset
>>> DatasetDict({
    train: Dataset({
        features: ['id', 'url', 'title', 'text'],
        num_rows: 3000
    })
})

Then we load our embedding model. I'm going to go with mixedbread-ai/mxbai-embed-large-v1:

from sentence_transformers import SentenceTransformer
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
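
Embedding runs much faster on a GPU. If one is available, you can pass the device explicitly when loading the model (an optional tweak; SentenceTransformer accepts a device argument):

ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", device="cuda")  # run the encoder on the GPU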

Now let's embed our dataset:

def embed(batch):
    """
    adds a column to the dataset called 'embeddings'
    """
    # or you can combine multiple columns here
    # For example the title and the text
    information = batch["text"]
    return {"embeddings" : ST.encode(information)}

dataset = dataset.map(embed, batched=True, batch_size=16)

It is advised that you save your dataset to avoid going through this step each time.

To keep the original dataset intact for all users, I will push the embedded one to a new branch; this is easy to do with the revision parameter:

dataset.push_to_hub("not-lain/wikipedia", revision="embedded")
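
If you'd rather not push to the Hub, saving the embedded dataset locally works just as well (a small alternative using the standard datasets API; the folder name here is only an example):

dataset.save_to_disk("wikipedia_embedded")  # write the embedded dataset to a local folder

# later, reload it without re-embedding
from datasets import load_from_disk
dataset = load_from_disk("wikipedia_embedded")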

Search through the dataset

Load your embedded dataset from the Hub:

from datasets import load_dataset

dataset = load_dataset("not-lain/wikipedia",revision = "embedded")

Then add a FAISS index using the embeddings column that we created:

data = dataset["train"]
data = data.add_faiss_index("embeddings")
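
Building the index on 3,000 rows is quick, but for larger datasets you can persist it and attach it again later instead of rebuilding it (a sketch using the datasets FAISS helpers; the file name is just an example):

data.save_faiss_index("embeddings", "wikipedia.faiss")  # write the index to disk

# in a later session, attach the saved index to the dataset
data.load_faiss_index("embeddings", "wikipedia.faiss")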

Let's define a search function:

def search(query: str, k: int = 3 ):
    """a function that embeds a new query and returns the most probable results"""
    embedded_query = ST.encode(query) # embed new query
    scores, retrieved_examples = data.get_nearest_examples( # retrieve results
        "embeddings", embedded_query, # compare our new embedded query with the dataset embeddings
        k=k # get only top k results
    )
    return scores, retrieved_examples
# search for the word anarchy and get the best 4 matching values from the dataset
scores, result = search("anarchy", 4)
result['title']
>>> ['Anarchism', 'Anarcho-capitalism', 'Community', 'Capitalism']
print(result["text"][0])
>>>"Anarchism is a political philosophy and movement that is skeptical of all justifications for authority and (...)"

RAG chatbot

The following is a draft of what a RAG chatbot might look like:

embed (only once)
│
└── new query
    │
    └── retrieve
        │
        └─── format prompt
            │
            └── GenAI
                │
                └── generate response

Now let's put everything together in a new session after embedding:

pip install -q datasets sentence-transformers faiss-cpu accelerate bitsandbytes
from sentence_transformers import SentenceTransformer
from datasets import load_dataset

ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

dataset = load_dataset("not-lain/wikipedia",revision = "embedded")

data = dataset["train"]
data = data.add_faiss_index("embeddings") # column name that has the embeddings of the dataset

def search(query: str, k: int = 3 ):
    """a function that embeds a new query and returns the most probable results"""
    embedded_query = ST.encode(query) # embed new query
    scores, retrieved_examples = data.get_nearest_examples( # retrieve results
        "embeddings", embedded_query, # compare our new embedded query with the dataset embeddings
        k=k # get only top k results
    )
    return scores, retrieved_examples
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# use quantization to lower GPU usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config
)
# Llama 3 marks the end of a turn with <|eot_id|>, so we stop generation on either token
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

We recommend you set up a system prompt to guide the LLM in generating responses.

SYS_PROMPT = """You are an assistant for answering questions.
You are given the extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer."""
def format_prompt(prompt,retrieved_documents,k):
  """using the retrieved documents we will prompt the model to generate our responses"""
  PROMPT = f"Question:{prompt}\nContext:"
  for idx in range(k) :
    PROMPT+= f"{retrieved_documents['text'][idx]}\n"
  return PROMPT

def generate(formatted_prompt):
  formatted_prompt = formatted_prompt[:2000] # to avoid GPU OOM
  messages = [{"role":"system","content":SYS_PROMPT},{"role":"user","content":formatted_prompt}]
  # tell the model to generate
  input_ids = tokenizer.apply_chat_template(
      messages,
      add_generation_prompt=True,
      return_tensors="pt"
  ).to(model.device)
  outputs = model.generate(
      input_ids,
      max_new_tokens=1024,
      eos_token_id=terminators,
      do_sample=True,
      temperature=0.6,
      top_p=0.9,
  )
  response = outputs[0][input_ids.shape[-1]:]
  return tokenizer.decode(response, skip_special_tokens=True)

def rag_chatbot(prompt:str,k:int=2):
  scores , retrieved_documents = search(prompt, k)
  formatted_prompt = format_prompt(prompt,retrieved_documents,k)
  return generate(formatted_prompt)
rag_chatbot("what's anarchy ?", k = 2)
>>>"So, anarchism is a political philosophy that questions the need for authority and hierarchy, and (...)"

App

The final code used to create the app can be found here.
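
As a rough idea of what such an app looks like, here is a minimal sketch that wraps rag_chatbot in a Gradio chat interface (an illustration, not the exact app code; the chat history is ignored for simplicity):

import gradio as gr

def chat(message, history):
    # each question is answered independently in this simple sketch
    return rag_chatbot(message, k=2)

demo = gr.ChatInterface(chat, title="RAG chatbot using llama3")
demo.launch()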

Dedication

In loving memory of Rayner V. Giuret, a friend, a brother, and an idol to all of us at LowRes.
Your legacy lives on in our hearts and minds. Thanks for everything.

Rest in peace, Rayner.