"""Legal RAG chatbot.

Retrieves relevant lines from a law PDF and answers from a legal QA dataset,
then streams an answer from Gemma 2 27B (4-bit) through a Gradio chat UI.
"""
import os
from threading import Thread

import faiss
import fitz  # PyMuPDF
import gradio as gr
import spaces
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

# Hugging Face access token, read from the environment (may be None).
token = os.environ.get("HF_TOKEN")

# Sentence-embedding model used for both the law index and the QA index.
ST = SentenceTransformer("jhgan/ko-sroberta-multitask")


def extract_text_from_pdf(pdf_path):
    """Return the concatenated text of every page of the PDF at *pdf_path*.

    The document is opened in a ``with`` block so the file handle is released
    even if extraction fails part-way.
    """
    with fitz.open(pdf_path) as doc:
        return "".join(page.get_text() for page in doc)


# Path of the law PDF to index. Replace with the real path of your document.
pdf_path = "laws.pdf"
law_text = extract_text_from_pdf(pdf_path)

# Split the law text into line-level chunks and embed them. Blank and
# whitespace-only lines are dropped: they carry no content, yet would still be
# embedded and could be returned as "nearest" matches.
# Adjust the splitting to your PDF's structure.
law_sentences = [s for s in (line.strip() for line in law_text.split("\n")) if s]
if not law_sentences:
    raise ValueError(f"No text could be extracted from {pdf_path!r}")
law_embeddings = ST.encode(law_sentences)

# Exact (brute-force) L2-distance FAISS index over the law-line embeddings.
index = faiss.IndexFlatL2(law_embeddings.shape[1])
index.add(law_embeddings)

# Load the legal-consultation QA dataset from the Hugging Face Hub.
dataset = load_dataset("jihye-moon/LawQA-Ko")
data = dataset["train"]

# Embed the "question" column into a new column and build a FAISS index on it.
data = data.map(
    lambda x: {"question_embedding": ST.encode(x["question"])}, batched=True
)
data.add_faiss_index(column="question_embedding")

# LLM configuration: Gemma 2 27B instruction-tuned, loaded with 4-bit NF4
# quantization (double quantization, bfloat16 compute).
model_id = "google/gemma-2-27b-it"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
    token=token,
)

SYS_PROMPT = (
    "You are an assistant for answering legal questions. "
    "You are given the extracted parts of legal documents and a question. "
    "Provide a conversational answer. "
    "If you don't know the answer, just say \"I do not know.\" "
    "Don't make up an answer. \n"
    "you must answer korean."
)


@spaces.GPU
def search_law(query, k=5):
    """Return up to *k* ``(law_line, l2_distance)`` pairs nearest to *query*.

    FAISS pads the result with label -1 when the index holds fewer than *k*
    vectors. Those slots are skipped; otherwise ``law_sentences[-1]`` would
    silently be returned as a bogus match.
    """
    query_embedding = ST.encode([query])
    distances, indices = index.search(query_embedding, k)
    return [
        (law_sentences[i], distances[0][rank])
        for rank, i in enumerate(indices[0])
        if i >= 0
    ]


@spaces.GPU
def search_qa(query, k=3):
    """Return the answers of the (up to) *k* QA rows whose questions are nearest to *query*."""
    _scores, retrieved_examples = data.get_nearest_examples(
        "question_embedding", ST.encode(query), k=k
    )
    # Use the returned column as-is: it can hold fewer than k rows, in which
    # case indexing over range(k) would raise IndexError.
    return list(retrieved_examples["answer"])


def format_prompt(prompt, law_docs, qa_docs):
    """Build the final user prompt from the question and retrieved context.

    Args:
        prompt: The user's question.
        law_docs: Retrieved law passages, either plain strings or
            ``(text, score)`` tuples as returned by ``search_law``.
        qa_docs: Retrieved QA answers.

    Returns:
        The prompt string with "Legal Context" and "Legal QA" sections.
    """
    # Accept both shapes. Taking ``doc[0]`` of a plain string would keep only
    # its first character.
    law_texts = [doc[0] if isinstance(doc, tuple) else doc for doc in law_docs]
    parts = [f"Question: {prompt}\n\nLegal Context:\n"]
    parts.extend(f"{text}\n" for text in law_texts)
    parts.append("\nLegal QA:\n")
    parts.extend(f"{doc}\n" for doc in qa_docs)
    return "".join(parts)


@spaces.GPU
def talk(prompt, history):
    """Stream an answer to *prompt*, grounded in retrieved law text and QA pairs.

    Args:
        prompt: The user's message.
        history: Chat history passed by ``gr.ChatInterface``; not used here.

    Yields:
        The accumulated response text after each streamed chunk.
    """
    law_results = search_law(prompt, k=3)
    qa_results = search_qa(prompt, k=3)
    retrieved_law_docs = [text for text, _distance in law_results]
    formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
    # Cap the prompt length to avoid running out of GPU memory.
    formatted_prompt = formatted_prompt[:2000]

    # The system prompt is folded into the single user message rather than
    # sent as a separate "system" role (presumably because the Gemma chat
    # template does not accept one — verify if switching models).
    messages = [{"role": "user", "content": SYS_PROMPT + "\n" + formatted_prompt}]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        temperature=0.2,
        eos_token_id=tokenizer.eos_token_id,
    )

    # generate() runs in a background thread and feeds the streamer, which we
    # drain here to yield partial output to the UI.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted only once generation finishes, so this join is
    # immediate; it ensures the worker thread is not left dangling.
    t.join()


# Gradio interface configuration.
TITLE = "Legal RAG Chatbot"
DESCRIPTION = (
    "A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation. \n"
    "This chatbot can search legal documents and previous legal QA pairs to provide answers."
)

demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["What are the regulations on data privacy?"]],
    title=TITLE,
    description=DESCRIPTION,
)

# Launch the Gradio demo.
demo.launch(debug=True)