# heesuuuuuu's Hugging Face Space — app.py (commit a869f0d, verified)
# --- Environment setup ---
# Point the Hugging Face cache at a writable location *before* any HF
# library reads it.
import os
os.environ["HF_HOME"] = "/home/user/.cache/huggingface"

import torch

# Model id used throughout this app. The model itself is loaded exactly
# once, further below, via AutoModelForCausalLM.
# NOTE(review): the original code additionally loaded the full model here
# through HuggingFacePipeline.from_model_id (with a conflicting
# device=None / device_map="auto" combination), but that `llm` object was
# overwritten later in the file without ever being used. The duplicate
# ~2B-parameter load (double RAM, double startup time) has been removed;
# HuggingFacePipeline is still imported below where it is actually used.
model_id = "beomi/gemma-ko-2b"
# %%  (notebook cell marker retained from the original script)
import os
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favour of HF_HOME, and this path conflicts with the HF_HOME
# set at the top of the file ('/home/user/.cache/huggingface'). Confirm
# which cache location is intended — '/data/heesu/...' likely does not
# exist on a Hugging Face Space.
os.environ['TRANSFORMERS_CACHE'] = '/data/heesu/huggingface_cache'
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
import gradio as gr # Gradio import for the web UI
# --- 1. Load the vector DB and the embedding model ---
print("Vector DB와 임베딩 모델을 불러오는 중입니다...")

# Korean sentence-embedding model, run on CPU with unit-normalized
# vectors (cosine-similarity friendly).
index_path = "proj2_voca"
model_name = "jhgan/ko-sroberta-multitask"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': True}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)

# The FAISS index was serialized by this project itself, hence the
# trusted-deserialization flag; top-2 neighbours are retrieved per query.
vectorstore = FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
retriever = vectorstore.as_retriever(search_kwargs={'k': 2})
# --- 2. Load the language model (LLM) ---
print("언어 모델(beomi/gemma-ko-2b)을 불러오는 중입니다. 시간이 다소 걸릴 수 있습니다...")
model_id = "beomi/gemma-ko-2b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load weights in bfloat16 with automatic device placement disabled,
# then pin the whole model to CPU.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map=None,
).to("cpu")
# Text-generation pipeline wrapping the CPU model loaded above.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,                           # -1 == force CPU
    max_new_tokens=170,                  # keep generated answers short
    temperature=0.7,                     # NOTE(review): only takes effect when sampling is enabled — confirm intent
    repetition_penalty=1.2,              # discourage repeated phrases
    pad_token_id=tokenizer.eos_token_id, # gemma has no dedicated pad token
)

# LangChain wrapper so the raw transformers pipeline can participate in
# runnable chains below.
llm = HuggingFacePipeline(pipeline=pipe)
# --- 3. Prompt template ---
# Korean instruction prompt: summarise the retrieved dream-interpretation
# snippets into a single-sentence answer. The trailing "### 전문가의 답변:"
# header cues the model to begin its answer there, which the response
# post-processing later uses as a split marker.
template = """
당신은 사용자의 꿈을 명확하게 해석해주는 꿈 해몽 전문가입니다. 아래 '검색된 꿈 해몽 정보'를 바탕으로, 사용자의 질문에 대해 가장 관련성이 높은 내용을 한 문장으로 요약해서 답변해주세요
### 검색된 꿈 해몽 정보:
{context}
### 사용자의 질문:
{question}
### 전문가의 답변:
"""
prompt = PromptTemplate(template=template, input_variables=["context", "question"])
# --- 4. Compose the RAG chain ---
# Feeds the retriever output directly in as {context} (a list of Documents)
# and the raw query as {question}, renders the prompt, runs the LLM, and
# parses the output to a plain string.
# NOTE(review): this chain appears unused — interpret_dream() below builds
# its own chain with pre-formatted context text; consider removing one of
# the two to avoid divergence.
rag_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
# --- ์‘๋‹ต ์ฒญ์†Œ ํ•จ์ˆ˜ ์ถ”๊ฐ€ (์›๋ณธ์—์„œ ๊ฐ€์ ธ์˜ด) ---
def clean_response(response):
# "###"๋กœ ๋ถ„ํ• ํ•˜๊ณ  ์ฒซ ๋ฒˆ์งธ ๋ถ€๋ถ„๋งŒ ๋ฐ˜ํ™˜ (๋ถˆํ•„์š”ํ•œ ํ›„์† ๋‚ด์šฉ ์ œ๊ฑฐ)
cleaned = response.split("###")[0].strip()
return cleaned
# --- 5. Dream-interpretation function (called by Gradio) ---
def interpret_dream(query):
    """Interpret a dream description using RAG over the FAISS dream index.

    Args:
        query: free-form (Korean) description of the user's dream.

    Returns:
        A one-sentence interpretation string, or an apology message when
        no sufficiently similar entry exists in the vector DB.
    """
    # Retrieve the 2 nearest entries together with their distance scores
    # (for FAISS distance metrics, lower score == more similar).
    docs_with_scores = vectorstore.similarity_search_with_score(query, k=2)
    score_threshold = 1.2  # empirically chosen relevance cut-off
    filtered_docs = [doc for doc, score in docs_with_scores if score < score_threshold]

    if not filtered_docs:
        return "죄송합니다. 현재 데이터베이스에 해당 꿈에 대한 정보가 부족하여 해몽을 제공하기 어렵습니다."

    # Build the context block. Use .get() so an entry missing the
    # 'meaning' metadata degrades gracefully instead of raising KeyError.
    context_texts = [
        f"- {d.page_content}: {d.metadata.get('meaning', '')}" for d in filtered_docs
    ]

    local_rag_chain = prompt | llm | StrOutputParser()
    response = local_rag_chain.invoke({
        "context": "\n".join(context_texts),
        "question": query,
    })

    # The pipeline echoes the prompt, so keep only the text after the
    # answer marker; if the marker is absent, fall back to the full
    # response (partition avoids the original's try/except IndexError).
    _, sep, tail = response.partition("### 전문가의 답변:")
    final_answer = tail.strip() if sep else response.strip()

    # Drop any trailing "###" sections the model kept generating.
    return clean_response(final_answer)
# --- 6. Gradio interface ---
with gr.Blocks() as demo:
    gr.Markdown("# ✨ 꿈 해몽 서비스")
    gr.Markdown("당신의 꿈 내용을 입력하세요. AI가 명확하게 해석해 드립니다!")

    # Input and output boxes side by side; the output is read-only.
    with gr.Row():
        dream_box = gr.Textbox(
            label="당신의 꿈은 무엇이었나요?",
            placeholder="예: 파란 불이 나는 꿈",
        )
        result_box = gr.Textbox(label="꿈 해몽 결과", interactive=False)

    run_btn = gr.Button("해몽하기")
    run_btn.click(interpret_dream, inputs=dream_box, outputs=result_box)
# ์•ฑ ์‹คํ–‰
if __name__ == '__main__':
demo.launch(share=True) # share=True๋กœ ๊ณต๊ณต URL ์ƒ์„ฑ (์„ ํƒ์‚ฌํ•ญ)