# BibleGPT2 / app.py
# Author: manjunathshiva
# Last commit: 8e6f83d ("Update app.py", verified)
from llama_index.core import (
VectorStoreIndex
)
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from typing import Any, List, Tuple
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import streamlit as st
from llama_index.llms.huggingface import (
HuggingFaceInferenceAPI
)
import os
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
Q_END_POINT = os.environ.get("Q_END_POINT")
Q_API_KEY = os.environ.get("Q_API_KEY")

# SPLADE sparse encoders used for Qdrant hybrid search, following
# https://docs.llamaindex.ai/en/stable/examples/vector_stores/qdrant_hybrid.html
SPLADE_DOC_MODEL = "naver/efficient-splade-VI-BT-large-doc"
SPLADE_QUERY_MODEL = "naver/efficient-splade-VI-BT-large-query"

device = "cuda:0" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_splade(model_name: str, target_device: str) -> Tuple[Any, Any]:
    """Load a SPLADE tokenizer and masked-LM model, with the model on *target_device*.

    Streamlit re-executes this script on every user interaction;
    ``st.cache_resource`` makes the download/load happen once per process
    instead of once per chat message.

    Returns:
        ``(tokenizer, model)`` for *model_name*.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name).to(target_device)
    # Inference only: make sure dropout and other training-time behavior is off.
    model.eval()
    return tokenizer, model


doc_tokenizer, doc_model = _load_splade(SPLADE_DOC_MODEL, device)
query_tokenizer, query_model = _load_splade(SPLADE_QUERY_MODEL, device)
def sparse_doc_vectors(
    texts: List[str],
) -> Tuple[List[List[int]], List[List[float]]]:
    """
    Encode documents into SPLADE sparse vectors for hybrid indexing.

    The texts are tokenized as one padded, truncated batch and run through the
    SPLADE document model. Token logits are turned into weights with
    ``log(1 + relu(logits))``, and padding positions are zeroed with the
    attention mask. A max over the sequence axis then yields one weight per
    output logit for each text.

    Args:
        texts: Document texts to encode.

    Returns:
        ``(indices, values)``: for each text, the positions with a non-zero
        weight and the corresponding weights, in matching order. Both lists
        are empty when *texts* is empty.
    """
    if not texts:
        return [], []
    tokens = doc_tokenizer(
        texts, truncation=True, padding=True, return_tensors="pt"
    )
    # Send inputs to wherever the model actually lives instead of a
    # hard-coded device that may not match (or may not exist).
    tokens = tokens.to(doc_model.device)
    # Pure inference: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        output = doc_model(**tokens)
    logits, attention_mask = output.logits, tokens.attention_mask
    relu_log = torch.log(1 + torch.relu(logits))
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    # Max-pool over the sequence axis, then move to CPU once for extraction.
    tvecs = torch.max(weighted_log, dim=1).values.cpu()
    # Extract the non-zero entries of each vector and their indices.
    indices = []
    vecs = []
    for row in tvecs:
        nonzero = row.nonzero(as_tuple=True)[0]
        indices.append(nonzero.tolist())
        vecs.append(row[nonzero].tolist())
    return indices, vecs
def sparse_query_vectors(
    texts: List[str],
) -> Tuple[List[List[int]], List[List[float]]]:
    """
    Encode queries into SPLADE sparse vectors for hybrid retrieval.

    Same pipeline as the document encoder, but using the SPLADE query model:
    ``log(1 + relu(logits))``, padding masked out via the attention mask, then
    a max over the sequence axis.

    Args:
        texts: Query texts to encode (one padded, truncated batch).

    Returns:
        ``(indices, values)``: for each text, the positions with a non-zero
        weight and the corresponding weights, in matching order. Both lists
        are empty when *texts* is empty.
    """
    # TODO: compute sparse vectors in batches if max length is exceeded
    if not texts:
        return [], []
    tokens = query_tokenizer(
        texts, truncation=True, padding=True, return_tensors="pt"
    )
    # Send inputs to wherever the model actually lives instead of a
    # hard-coded device that may not match (or may not exist).
    tokens = tokens.to(query_model.device)
    # Pure inference: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        output = query_model(**tokens)
    logits, attention_mask = output.logits, tokens.attention_mask
    relu_log = torch.log(1 + torch.relu(logits))
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    # Max-pool over the sequence axis, then move to CPU once for extraction.
    tvecs = torch.max(weighted_log, dim=1).values.cpu()
    # Extract the non-zero entries of each vector and their indices.
    indices = []
    vecs = []
    for row in tvecs:
        nonzero = row.nonzero(as_tuple=True)[0]
        indices.append(nonzero.tolist())
        vecs.append(row[nonzero].tolist())
    return indices, vecs
st.header("Chat with the Bible docs 💬 📚")

# Streamlit reruns this script on every interaction: seed the chat history
# only on a session's first run so earlier messages survive later reruns.
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": "Ask me a question about Bible!"}
    ]
@st.cache_resource
def _load_rag_resources() -> Tuple[Any, ...]:
    """Build the process-wide retrieval resources once and reuse them on reruns.

    Streamlit re-executes this script on each interaction; caching here avoids
    reconnecting to Qdrant and reloading the tokenizer and embedding model for
    every chat message. The LlamaIndex ``Settings`` object is module state,
    which persists for the same process lifetime as this cache, so configuring
    it here once is sufficient.

    Returns:
        ``(client, vector_store, llm, embed_model, index)``.
    """
    # Client for the remote Qdrant instance holding the "bible" collection.
    client = QdrantClient(
        Q_END_POINT,
        api_key=Q_API_KEY,
    )
    # Vector store with hybrid (dense + SPLADE sparse) retrieval enabled;
    # batch_size controls how many nodes are encoded with sparse vectors at once.
    vector_store = QdrantVectorStore(
        "bible",
        client=client,
        enable_hybrid=True,
        batch_size=20,
        force_disable_check_same_thread=True,
        sparse_doc_fn=sparse_doc_vectors,
        sparse_query_fn=sparse_query_vectors,
    )
    llm = HuggingFaceInferenceAPI(
        model_name="meta-llama/Meta-Llama-3-8B-Instruct",
        token=HUGGINGFACEHUB_API_TOKEN,
        # NOTE(review): 8096 may be a typo for 8192; a smaller window is the
        # safe direction, so the value is left unchanged — confirm.
        context_window=8096,
    )
    Settings.llm = llm
    # Attribute was previously misspelled ("tokenzier"), so LlamaIndex never
    # received this tokenizer. The model repo may require authentication,
    # hence the token (None falls back to the default credentials lookup).
    Settings.tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        token=HUGGINGFACEHUB_API_TOKEN,
    )
    embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5", device="cpu")
    Settings.embed_model = embed_model
    index = VectorStoreIndex.from_vector_store(
        vector_store=vector_store, embed_model=embed_model
    )
    return client, vector_store, llm, embed_model, index


client, vector_store, llm, embed_model, index = _load_rag_resources()

from llama_index.core.memory import ChatMemoryBuffer

# The chat engine owns the conversation memory that condense_question mode
# uses to rewrite follow-up questions. Keep it per session in session_state;
# rebuilding it on every rerun would discard that history after each message.
if "chat_engine" not in st.session_state:
    st.session_state.chat_memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
    st.session_state.chat_engine = index.as_chat_engine(
        chat_mode="condense_question",
        verbose=True,
        memory=st.session_state.chat_memory,
        sparse_top_k=10,
        vector_store_query_mode="hybrid",
        similarity_top_k=3,
    )
memory = st.session_state.chat_memory
chat_engine = st.session_state.chat_engine
# Prompt for user input and record it in the chat history.
if prompt := st.chat_input("Your question"):
    st.session_state.messages.append({"role": "user", "content": prompt})

# Replay the conversation so far.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

# If the last message is not from the assistant, answer it. The question is
# taken from the history rather than `prompt`: st.chat_input returns None on
# reruns not triggered by a new submission, e.g. after a failed request left
# an unanswered user message at the end of the history.
if st.session_state.messages[-1]["role"] != "assistant":
    question = st.session_state.messages[-1]["content"]
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            response = chat_engine.chat(question)
            st.write(response.response)
            message = {"role": "assistant", "content": response.response}
            st.session_state.messages.append(message)  # Add response to message history