# BioGeo / app.py
import gradio as gr
from transformers import AutoTokenizer
import os
import spaces
import torch
from llama_index.llms.huggingface import HuggingFaceLLM
# Optional quantization to 4bit
from transformers import BitsAndBytesConfig
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings
import faiss
from llama_index.core import (
    load_index_from_storage,
    StorageContext,
)
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from llama_index.core.agent import ReActAgent
import nest_asyncio
HF_TOKEN = os.environ.get("HF_TOKEN", None)
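# Gemma is a gated model on the Hugging Face Hub, so HF_TOKEN is expected to be
# a token with access to google/gemma-1.1-2b-it (e.g. set as a Space secret).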
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Gemma 1.1 2B IT</h1>
<p>This Space demonstrates agent-based RAG over multiple documents using Gemma 1.1 2B IT and LlamaIndex.</p>
</div>
'''
tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-1.1-2b-it",
    token=HF_TOKEN,
)
# Stop generation at the end-of-sequence token or at Gemma's end-of-turn marker
stopping_ids = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<end_of_turn>"),
]
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
llm = HuggingFaceLLM(
    model_name="google/gemma-1.1-2b-it",
    model_kwargs={
        "token": HF_TOKEN,
        "torch_dtype": torch.bfloat16,  # comment this line and uncomment below to use 4bit
        # "quantization_config": quantization_config,
    },
    generate_kwargs={
        "do_sample": True,
        "temperature": 0.6,
        "top_p": 0.9,
    },
    tokenizer_name="google/gemma-1.1-2b-it",
    tokenizer_kwargs={"token": HF_TOKEN},
    stopping_ids=stopping_ids,
)
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
# dimensions of bge-large-en-v1.5 obtained from https://huggingface.co/BAAI/bge-large-en-v1.5
d = 1024
faiss_index = faiss.IndexFlatL2(d)
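# Note: this in-memory index is illustrative only; the app loads prebuilt FAISS
# indices from disk below rather than populating this one.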
nest_asyncio.apply()
# bge embedding model for retrieval
Settings.embed_model = embed_model
# Gemma 1.1 2B IT (configured above) for generation
Settings.llm = llm
# rebuild storage context
geoVectorStore = FaissVectorStore.from_persist_dir("./geoindex/")
geoStorageContext = StorageContext.from_defaults(
    vector_store=geoVectorStore, persist_dir="./geoindex/"
)
geoindex = load_index_from_storage(storage_context=geoStorageContext)
bioVectorStore = FaissVectorStore.from_persist_dir("./bioindex/")
bioStorageContext = StorageContext.from_defaults(
    vector_store=bioVectorStore, persist_dir="./bioindex/"
)
bioindex = load_index_from_storage(storage_context=bioStorageContext)
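# A minimal sketch of how the persisted indices could have been built
# (assumed workflow; "./geo_docs" is a hypothetical source directory, not a
# path that exists in this repo):
#
#   from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
#   docs = SimpleDirectoryReader("./geo_docs").load_data()
#   store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(d))
#   ctx = StorageContext.from_defaults(vector_store=store)
#   index = VectorStoreIndex.from_documents(docs, storage_context=ctx)
#   index.storage_context.persist(persist_dir="./geoindex/")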
geo_engine = geoindex.as_query_engine(similarity_top_k=3)
bio_engine = bioindex.as_query_engine(similarity_top_k=3)
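# Wrap each query engine as a tool the ReAct agent can call; the name and
# description below guide the agent's tool selection.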
query_engine_tools = [
    QueryEngineTool(
        query_engine=geo_engine,
        metadata=ToolMetadata(
            name="geography",
            description=(
                "This is a geography textbook; it provides information about geography. "
                "Use a detailed plain text question as input to the tool."
            ),
        ),
    ),
    QueryEngineTool(
        query_engine=bio_engine,
        metadata=ToolMetadata(
            name="biology",
            description=(
                "This is a biology textbook; it provides information about biology. "
                "Use a detailed plain text question as input to the tool."
            ),
        ),
    ),
]
agent = ReActAgent.from_tools(
    query_engine_tools,
    llm=llm,
    verbose=False,
)
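# Optional sanity check of the agent wiring (uncomment to run once at startup;
# the question is taken from the UI examples below):
# print(agent.chat("What was the level of urbanisation in India in 2011?"))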
@spaces.GPU(duration=120)
def respond(
    message,
    history,
    # system_message,
    # max_tokens,
    # temperature,
    # top_p,
):
    prompt = (
        f"Analyze the question: {message} and use the appropriate tool to get the "
        "relevant context and answer the question; do not answer on your own, and "
        "output only the Observation."
    )
    response = agent.chat(prompt)
    return str(response)
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    # additional_inputs=[
    #     gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    #     gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    #     gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    #     gr.Slider(
    #         minimum=0.1,
    #         maximum=1.0,
    #         value=0.95,
    #         step=0.05,
    #         label="Top-p (nucleus sampling)",
    #     ),
    # ],
    examples=[
        ["What are different types of rural settlement?"],
        ["Explain urbanisation in India."],
        ["What was the level of urbanisation in India in 2011?"],
        ["List the religious and cultural towns in India."],
    ],
    cache_examples=False,
)
if __name__ == "__main__":
    demo.launch()