# Hugging Face Spaces status: Runtime error
import os

import gradio as gr
import nltk
import openai
import pandas as pd
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.vectorstores import FAISS
from langchain_core.output_parsers.string import StrOutputParser
from sentence_transformers import SentenceTransformer

# NOTE(review): WordNet is downloaded at startup, but nothing in this file uses
# nltk afterwards — remove this if it is not needed elsewhere.
nltk.download('wordnet')

# Query encoder. The vectors stored in "Faiss_index" must have been produced by
# this same model for nearest-neighbour search to be meaningful — confirm
# against the script that built the index.
embedder = SentenceTransformer('all-mpnet-base-v2')

# Read the OpenAI key from the environment (e.g. a Space secret). Never put a
# key in source: the key that previously appeared in a comment on this line
# has been exposed and must be revoked.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Load the FAISS index through LangChain's wrapper. allow_dangerous_deserialization
# is needed because the saved docstore is pickled, so only load a trusted index.
# NOTE(review): a SentenceTransformer is not a LangChain Embeddings object, so
# text-query helpers on `db` (e.g. similarity_search_with_score) may not work
# with it — confirm.
db = FAISS.load_local("Faiss_index", embedder, allow_dangerous_deserialization=True)

parser = StrOutputParser()

# Article table; the code reads its title, author, full_text and url columns.
# pd.read_pickle unpickles arbitrary objects, so only load a trusted file.
df = pd.read_pickle('df_news (1).pkl')

# TODO(review): the chunk lookup in search() reads a module-level
# `metadata_info` (records with 'index' and 'chunk' keys, indexed by chunk
# position) that is never defined in this file; load it here from wherever the
# chunk vectors were built.
# Search function to retrieve relevant documents
def _records_from_ids(ids, frame, chunk_meta, num_doc_vectors):
    """Convert FAISS row ids into result records.

    Ids below ``num_doc_vectors`` are whole-article vectors and map directly to
    rows of ``frame``. Larger ids are chunk vectors: the entry of ``chunk_meta``
    at ``id - num_doc_vectors`` supplies the parent row (``'index'``) and the
    chunk text (``'chunk'``). Negative ids are skipped; FAISS uses -1 to pad
    the result when the index holds fewer than ``k`` vectors.

    Raises:
        RuntimeError: a chunk id is encountered but ``chunk_meta`` is None.
    """
    records = []
    for raw_id in ids:
        idx = int(raw_id)
        if idx < 0:
            continue
        if idx < num_doc_vectors:
            row = frame.iloc[idx]
            records.append({
                'type': 'metadata',
                'title': row['title'],
                'author': row['author'],
                'full_text': row['full_text'],
                'source': row['url'],
            })
            continue
        if chunk_meta is None:
            raise RuntimeError(
                f"FAISS id {idx} refers to a text chunk, but no chunk metadata "
                "(metadata_info) has been loaded."
            )
        meta = chunk_meta[idx - num_doc_vectors]
        row = frame.iloc[meta['index']]
        records.append({
            'type': 'content',
            'title': row['title'],
            'author': row['author'],
            'content': meta['chunk'],
            'source': row['url'],
        })
    return records


def search(query, k=10, num_doc_vectors=3327, metadata_info=None):
    """Return the ``k`` nearest articles / article chunks to *query*.

    Args:
        query: Free-text search query.
        k: Number of neighbours to request from the FAISS index.
        num_doc_vectors: Number of leading index vectors that represent whole
            articles; ids at or above it are chunk vectors. Adjust this to
            match how the index was built.
        metadata_info: Chunk metadata (see ``_records_from_ids``). When None,
            a module-level ``metadata_info`` is used if one is defined.

    Returns:
        A list of dicts with ``type`` 'metadata' (keys title, author,
        full_text, source) or 'content' (keys title, author, content, source).
    """
    # SentenceTransformer exposes ``encode`` (it has no ``embed_query``);
    # FAISS expects a float32 matrix with one row per query.
    query_embedding = embedder.encode(query).reshape(1, -1).astype('float32')
    # Query the raw FAISS index, which returns (distances, ids) — the id-based
    # mapping below needs ids. The LangChain ``similarity_search_with_score``
    # takes a text query and returns (Document, score) pairs instead.
    _, ids = db.index.search(query_embedding, k)
    if metadata_info is None:
        # Backward compatibility with code that defines the metadata globally.
        metadata_info = globals().get('metadata_info')
    return _records_from_ids(ids[0], df, metadata_info, num_doc_vectors)
# Generate an answer based on the retrieved documents

# Human-turn template. The retrieved context and the user's question are
# supplied as template variables at invoke time rather than pasted into the
# template text, so braces inside documents cannot be misread as placeholders.
_ANSWER_TEMPLATE = """
Answer the question based on the context below. If you can't answer the question, answer with "I don't know".
Context: {context}
Question: {question}
"""


def _format_context(docs):
    """Join retrieved records into a single context string.

    Each record becomes a ``Title: <title>`` line followed by a
    ``Content: <text>`` line, where text is the record's ``content`` (chunk
    hits) or else its ``full_text`` (whole-article hits), or empty. Records
    are separated by a blank line.
    """
    return "\n\n".join(
        f"Title: {doc['title']}\nContent: {doc.get('content', doc.get('full_text', ''))}"
        for doc in docs
    )


def generate_answer(query, system_message="You are a helpful assistant.",
                    max_tokens=1500, temperature=0.2):
    """Answer *query* with GPT-4, grounded in documents returned by ``search``.

    Args:
        query: The user's question.
        system_message: System prompt for the chat model.
        max_tokens: Maximum number of tokens the model may generate.
        temperature: Sampling temperature for the model.

    Returns:
        The model's reply with surrounding whitespace stripped.
    """
    context_str = _format_context(search(query))

    chat = ChatOpenAI(
        model="gpt-4",
        temperature=temperature,
        max_tokens=max_tokens,
        api_key=openai.api_key,
    )
    # The system message is also a variable so user-provided text containing
    # braces is passed through verbatim.
    prompt = ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template("{system_message}"),
        HumanMessagePromptTemplate.from_template(_ANSWER_TEMPLATE),
    ])
    chain = prompt | chat | StrOutputParser()
    response = chain.invoke({
        "system_message": system_message,
        "context": context_str,
        "question": query,
    })
    return response.strip()
# Gradio chat interface
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Gradio ``ChatInterface`` callback that yields the answer to *message*.

    ``history`` and the four additional-input values are part of the callback
    signature Gradio calls with, but they are not used here: the reply comes
    only from ``generate_answer(message)`` and is yielded exactly once.
    """
    yield generate_answer(message)
# Gradio demo setup
# Extra controls rendered under the chat box. Gradio passes their values to
# respond() after (message, history), in this order: system_message,
# max_tokens, temperature, top_p.
_additional_inputs = [
    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
]

demo = gr.ChatInterface(fn=respond, additional_inputs=_additional_inputs)

if __name__ == "__main__":
    demo.launch()