D2Cell-chatbot / langchain_qwen_run.py
kenghuoxiong's picture
Upload 8 files
9d2b8a1 verified
raw
history blame
4.33 kB
from langchain_community.chat_models import ChatOpenAI
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.schema import HumanMessage, SystemMessage
import os
from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
import gradio as gr
import requests
from langchain_core.prompts import PromptTemplate
from qwen_api import qwen_api
def load_documents(directory='../langchain-database', *, chunk_size=2048, chunk_overlap=200):
    """Load the documents under *directory* and split them into chunks.

    Args:
        directory: Root folder scanned by ``DirectoryLoader``.
        chunk_size: Maximum chunk length handed to ``CharacterTextSplitter``
            (measured with the splitter's length function; characters by default).
        chunk_overlap: Overlap between consecutive chunks, in the same units.

    Returns:
        The list of split documents produced by ``split_documents``.
    """
    # silent_errors=True: files the loader cannot read are skipped rather than
    # aborting the whole load.
    loader = DirectoryLoader(directory, show_progress=True, use_multithreading=True, silent_errors=True)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    split_docs = text_splitter.split_documents(documents)
    # Report how many chunks were produced.
    print(len(split_docs))
    return split_docs
def load_embedding_mode(*, model_name="/home/xiongwen/bge-m3", device='cuda'):
    """Build the HuggingFace embedding model used for the vector store.

    Args:
        model_name: Local path or hub id of the sentence-embedding model.
        device: Torch device the model is loaded on (e.g. ``'cuda'`` or ``'cpu'``).

    Returns:
        A ``HuggingFaceEmbeddings`` instance.

    NOTE: the query-time model and settings must match those used when the
    persisted vector store was built, or similarity search will be meaningless.
    """
    # Embeddings are left unnormalized; changing this would be inconsistent
    # with any store already built using this function.
    encode_kwargs = {"normalize_embeddings": False}
    model_kwargs = {"device": device}
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
def store_chroma(docs, embedding, persist_directory='./VecterStore2'):
    """Embed *docs* with *embedding* into a Chroma vector store.

    Args:
        docs: Documents to index (e.g. the output of ``load_documents``).
        embedding: Embedding function used to vectorize the documents.
        persist_directory: Directory the store is written to. Pass ``None``
            for an in-memory-only store.

    Returns:
        The populated ``Chroma`` instance.
    """
    # Forward persist_directory so the argument actually takes effect. Recent
    # chromadb versions persist automatically, so no explicit db.persist() call
    # is needed.
    db = Chroma.from_documents(docs, embedding, persist_directory=persist_directory)
    return db
def chat(question, history):
    """Gradio ``ChatInterface`` callback.

    The first turn of a conversation (empty history) is answered by the
    module-level ``RetrievalQA`` chain ``qa``. Later turns are delegated to
    ``qwen_api`` together with the Gradio history.

    ``qa`` is created in this file's ``__main__`` block, so this function only
    works when the module is run as a script.

    Args:
        question: The user's latest message.
        history: Prior conversation turns as supplied by Gradio.

    Returns:
        The reply to display. On the first turn this is the chain's ``'result'``
        text; later turns return whatever ``qwen_api`` returns (presumably a
        string — confirm).
    """
    # NOTE(review): whether qwen_api performs retrieval is not visible here.
    # `not history` treats both an empty list and None as "first turn".
    if not history:
        return qa.invoke(question)['result']
    return qwen_api(question, gradio_history=history)
if __name__ == '__main__':
    # Open the pre-built, persisted vector store. The embedding function must be
    # the same model that was used to build it.
    embedding = load_embedding_mode()
    db = Chroma(persist_directory='/home/xiongwen/llama2-a40-ner/langchain-qwen/VecterStore2_512_txt/VecterStore2_512_txt', embedding_function=embedding)

    # Route the OpenAI-compatible client to a locally served Qwen model
    # (presumably an OpenAI-compatible server on port 8000 — confirm). These are
    # set as environment variables rather than constructor arguments, possibly
    # because other code (e.g. qwen_api) also reads them — verify before changing.
    os.environ["OPENAI_API_BASE"] = 'http://localhost:8000/v1'
    os.environ["OPENAI_API_KEY"] = 'none'
    llm = ChatOpenAI(
        model="/home/xiongwen/Qwen1.5-110B-Chat",
        temperature=0.8,
    )

    # "stuff" prompt: retrieved chunks are inserted into {context}.
    prompt_template = """
{context}
The above content is a form of biological background knowledge. Please answer the questions according to the above content. Please be sure to answer the questions according to the background knowledge and attach the doi number of the information source when answering.
Question: {question}
Answer in English:"""
    PROMPT = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    chain_type_kwargs = {"prompt": PROMPT}

    retriever = db.as_retriever()
    # `qa` is a module-level global read by chat().
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs=chain_type_kwargs,
        return_source_documents=True,
    )

    interface = gr.ChatInterface(
        fn=chat,
        chatbot=gr.Chatbot(height=800, bubble_full_width=False),
        theme=gr.themes.Default(spacing_size='sm', radius_size='sm'),
        examples=['which gene should be knocked in the process of producing ethanol in Saccharomyces cerevisiae?'],
    )
    interface.launch(inbrowser=True)