# RAG question answering over local PDFs: Gemini generates the answers,
# Google embeddings + Chroma handle retrieval, and Gradio provides the UI.
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate
from deep_translator import GoogleTranslator
import PyPDF2
import os
import gradio as gr
import google.generativeai as genai

print('Initializing program')
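# The Gemini calls below expect the GOOGLE_API_KEY environment variable, e.g.
#   export GOOGLE_API_KEY="your-key-here"
# (hypothetical value); os.getenv returns None if it is unset.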
# Configure the Gemini SDK with the API key from the environment.
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

llm_model = 'gemini-1.5-flash'
embeddings_model = "models/embedding-001"
pdf_dir = 'data'  # folder containing the source PDFs
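# Both identifiers are passed to the langchain_google_genai wrappers below;
# swapping in another Gemini chat or embedding model only requires changing
# these two strings.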
print('-' * 21, 'Loading data', '-' * 21)
# Concatenate the extracted text of every PDF in pdf_dir into one string.
docs = ""
for filename in os.listdir(pdf_dir):
    if filename.endswith('.pdf'):
        try:
            with open(os.path.join(pdf_dir, filename), 'rb') as pdf_file:
                pdf_reader = PyPDF2.PdfReader(pdf_file)
                for page in pdf_reader.pages:
                    # extract_text() may return None for image-only pages.
                    docs += page.extract_text() or ""
            print('Loaded:', filename)
        except Exception as e:
            print('Failed to load:', filename, e)

print('-' * 21, 'Loading complete', '-' * 21)
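# PyPDF2 is in maintenance mode; its maintained successor 'pypdf' exposes the
# same PdfReader interface, so the import is a drop-in swap if needed.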
if docs:
    # Split the corpus into overlapping chunks sized for embedding.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
    texts = text_splitter.split_text(docs)

    embeddings = GoogleGenerativeAIEmbeddings(
        model=embeddings_model, google_api_key=os.getenv("GOOGLE_API_KEY")
    )
    # Index the chunks in Chroma and retrieve the top 5 matches per query.
    retriever = Chroma.from_texts(texts, embeddings).as_retriever(search_kwargs={"k": 5})
    print('Text splitting complete')
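# The Chroma index above lives in memory and is rebuilt on every run. A minimal
# sketch of on-disk persistence, assuming the hypothetical path 'chroma_db':
#
#     vectordb = Chroma.from_texts(texts, embeddings, persist_directory='chroma_db')
#     retriever = vectordb.as_retriever(search_kwargs={"k": 5})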
llm = ChatGoogleGenerativeAI(
    model=llm_model, temperature=0, google_api_key=os.getenv("GOOGLE_API_KEY")
)
print('Model loaded')
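# temperature=0 keeps answers as deterministic as the API allows, which suits
# retrieval-grounded QA where the retrieved context should drive the response.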
def translate_to_english(text):
    """Translate text in any language to English before retrieval."""
    return GoogleTranslator(source='auto', target='en').translate(text)


def translate_to_chinese(text):
    """Translate the model's English answer to Traditional Chinese."""
    return GoogleTranslator(source='auto', target='zh-TW').translate(text)
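# Both helpers call the public Google Translate endpoint via deep-translator,
# so each question/answer pair costs two extra network round-trips.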
# Conversation memory: a list of (question, answer) tuples fed to the chain.
chat_history = []
def invoke(question):
    print('invoke function triggered')
    if docs:
        # Feed the system prompt into the chain's QA step; the combine-docs
        # ("stuff") prompt must expose {context} and {question} variables.
        qa_prompt = PromptTemplate.from_template(
            "You are an assistant for question-answering tasks. "
            "Use the following pieces of retrieved context to answer the question.\n\n"
            "{context}\n\nQuestion: {question}"
        )
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            combine_docs_chain_kwargs={"prompt": qa_prompt},
        )

        # Ask in English, then translate the answer back to Traditional Chinese.
        question = translate_to_english(question)
        result = qa_chain.invoke({"question": question, "chat_history": chat_history})
        # Record the exchange so follow-up questions keep their context.
        chat_history.append((question, result['answer']))
        response = translate_to_chinese(result['answer'])
    else:
        response = 'No context!'

    return response
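# Design note: the chain is rebuilt on every call. from_llm only wires together
# the already-constructed llm and retriever, so this is cheap, but it could
# equally be created once at startup and reused.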
description = "Gradio UI using the Gemini-1.5-Flash model for RAG."
print(description)
gr.close_all()
demo = gr.Interface(
    fn=invoke,
    inputs=gr.Textbox(label="Question", lines=5),
    outputs=gr.Textbox(label="Response", lines=5),
    title="Gemini-RAG",
    description=description,
)

# share=True also exposes a temporary public Gradio URL alongside the local one.
demo.launch(share=True)