# RAG-test / main.py
# (Hugging Face Space page residue: author woonchen, commit "Update main.py", df63c83 verified)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
import PyPDF2
import os
import gradio as gr
import google.generativeai as genai
from langchain.chains import ConversationalRetrievalChain
from langchain_huggingface import HuggingFaceEmbeddings
from deep_translator import GoogleTranslator
print('程式初始化')
# 設定 Google API 金鑰
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
# 選擇模型
llm_model = 'gemini-1.5-flash'
embeddings_model = "models/embedding-001"
pdf_dir = 'data'
# 讀取 PDF 檔案
print('-' * 21, '讀取資料', '-' * 21)
docs = ""
for filename in os.listdir(pdf_dir):
if filename.endswith('.pdf'):
try:
with open(os.path.join(pdf_dir, filename), 'rb') as pdf_file:
pdf_reader = PyPDF2.PdfReader(pdf_file)
for i in range(len(pdf_reader.pages)):
page = pdf_reader.pages[i]
docs += page.extract_text()
print('讀取成功:',filename)
except:
print('讀取失敗:',filename)
print('-' * 21, '讀取完成', '-' * 21)
# 分割文本
if docs:
text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
texts = text_splitter.split_text(docs)
# 建立嵌入模型和檢索器
embeddings = GoogleGenerativeAIEmbeddings(
model=embeddings_model, google_api_key=os.getenv("GOOGLE_API_KEY")
)
retriever = Chroma.from_texts(texts, embeddings).as_retriever(search_kwargs={"k": 5})
print('分割文本完成')
# 初始化 Gemini 模型
llm = ChatGoogleGenerativeAI(
model=llm_model, temperature=0, google_api_key=os.getenv("GOOGLE_API_KEY")
)
print('模型載入完成')
# 定義翻譯函數
def translate_to_english(text):
return GoogleTranslator(source='auto', target='en').translate(text)
def translate_to_chinese(text):
return GoogleTranslator(source='auto', target='zh-TW').translate(text)
# 定義 invoke 函數
# 初始化 chat_history 为空
chat_history = []
def invoke(question):
print('invoke 函數觸發')
if docs:
system_prompt = (
"You are an assistant for question-answering tasks. "
"Use the following pieces of retrieved context to answer the question. "
)
#"If you don't know the answer, say that you don't know."
# 初始化 ConversationalRetrievalChain
qa_chain = ConversationalRetrievalChain.from_llm(
llm=llm, retriever=retriever
)
# 调用链并传递 chat_history
question = translate_to_english(question)
# chat_history = translate_to_english(chat_history)
response = qa_chain.invoke({"question": question, "chat_history": chat_history})
# response = qa_chain.invoke({"question": question})
response = translate_to_chinese(response['answer'])
# 更新 chat_history,保留上下文
# chat_history += question
# chat_history += response
else:
response = 'No context!'
return response
# Gradio 介面配置
description = "Gradio UI using the Gemini-1.5-Flash model for RAG."
print('description')
gr.close_all()
demo = gr.Interface(
fn=invoke,
inputs=gr.Textbox(label="Question", lines=5),
outputs=gr.Textbox(label="Response", lines=5),
title="Gemini-RAG",
description=description
)
demo.launch(share=True)