"""Gradio chat app that answers questions with Marco-o1 over a Chroma knowledge base.

For each user message the app looks up similar documents in a persisted
Chroma vector store, prepends them to the system message, and asks the
text-generation pipeline for a reply.
"""

import os

import gradio as gr
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from transformers import pipeline

# Load the embedding model
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Load the pre-existing vector database
persist_directory = "db"
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embeddings)

# Load the Marco-o1 model
# NOTE(review): trust_remote_code=True allows Python code shipped in the model
# repository to run on this machine -- confirm it is actually required.
pipe = pipeline(
    "text-generation",
    model="AIDC-AI/Marco-o1",
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
)


def get_relevant_context(query, k=3):
    """Return the text of the ``k`` stored documents most similar to ``query``.

    Args:
        query: Text to search the vector database with.
        k: Number of documents to retrieve.

    Returns:
        The ``page_content`` of the retrieved documents joined with newlines;
        an empty string when the search returns nothing.
    """
    docs = vectordb.similarity_search(query, k=k)
    return "\n".join(doc.page_content for doc in docs)


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate the assistant's reply to ``message``.

    The prompt is a newline-joined transcript: the system message (extended
    with retrieved context when there is any), every non-empty entry of the
    previous turns, then the new user message.

    Args:
        message: The latest user message.
        history: Previous ``(user, assistant)`` turns; empty entries are skipped.
        system_message: Instruction text placed at the top of the prompt.
        max_tokens: Upper bound on the number of newly generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The generated reply with surrounding whitespace stripped.
    """
    context = get_relevant_context(message)
    if context:
        system_prompt = f"{system_message}\n\nRelevant context:\n{context}"
    else:
        system_prompt = system_message

    messages = [system_prompt]
    for turn in history:
        if turn[0]:
            messages.append(turn[0])
        if turn[1]:
            messages.append(turn[1])
    messages.append(message)

    # Combine all messages into one string
    input_text = "\n".join(messages)

    outputs = pipe(
        input_text,
        # Budget the reply in tokens. The previous max_length was computed from
        # the prompt's character count, which is not a token count.
        max_new_tokens=int(max_tokens),
        # temperature/top_p are only honoured when sampling is enabled.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        num_return_sequences=1,
        # Return only the continuation so the prompt need not be sliced off.
        return_full_text=False,
    )
    yield outputs[0]["generated_text"].strip()


demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value=(
                "You are a helpful AI assistant. "
                "Use the provided context to answer questions accurately."
            ),
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title="Marco-O1 Assistant with Knowledge Base",
    description=(
        "Ask questions about the documents in the knowledge base. "
        "The assistant will use the relevant context to provide accurate answers."
    ),
)


if __name__ == "__main__":
    demo.launch()