# Hugging Face Spaces page header — Space status: Sleeping
import os

# Point the Hugging Face cache at a fixed directory. This is done before the
# Hugging Face imports below, since huggingface_hub resolves its cache location
# when it is first imported.
os.environ["HF_HOME"] = "/home/user/.cache/huggingface"

from langchain_community.llms import HuggingFacePipeline
import torch

# Select the beomi/gemma-ko-2b model.
model_id = "beomi/gemma-ko-2b"

# NOTE(review): `llm` is rebound further down in this file to a pipeline built
# from an explicitly loaded CPU copy of the same model, so this first load looks
# like redundant work (a second copy of the weights in memory). It is kept
# because removing it changes when transformers is first imported, which in
# turn decides whether the later TRANSFORMERS_CACHE assignment takes effect.
_first_load_model_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",  # let accelerate decide where the weights go
}
llm = HuggingFacePipeline.from_model_id(
    model_id=model_id,
    task="text-generation",
    # Placement is delegated to `device_map`; `device` is left as None so the
    # two settings do not conflict.
    device=None,
    model_kwargs=_first_load_model_kwargs,
)
# %% | |
import os

# NOTE(review): transformers reads TRANSFORMERS_CACHE when it is first
# imported. If the previous cell already ran, transformers was very likely
# imported there (HuggingFacePipeline.from_model_id loads the model through
# it), so in a top-to-bottom run this assignment presumably has no effect.
# Recent transformers releases also deprecate this variable in favour of
# HF_HOME, which the previous cell sets to a different directory, and the path
# below looks specific to one machine. Confirm which cache is intended.
os.environ['TRANSFORMERS_CACHE'] = '/data/heesu/huggingface_cache'

import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
import gradio as gr  # web UI for the demo

# --- 1. Load the vector DB and its embedding model ---
print("Vector DB์ ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ๋ถ๋ฌ์ค๋ ์ค์ ๋๋ค...")

index_path = "proj2_voca"  # folder containing the saved FAISS index

# Sentence-embedding model run on the CPU; embeddings are normalised to unit
# length at encode time.
model_name = "jhgan/ko-sroberta-multitask"
model_kwargs = dict(device='cpu')
encode_kwargs = dict(normalize_embeddings=True)
embeddings = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)

# load_local unpickles the stored docstore, hence the explicit opt-in flag;
# only load index files you trust (here: a local directory).
vectorstore = FAISS.load_local(
    index_path, embeddings, allow_dangerous_deserialization=True
)
# Retriever returning the 2 nearest documents per query.
retriever = vectorstore.as_retriever(search_kwargs=dict(k=2))
# --- 2. Load the language model (LLM) ---
print("์ธ์ด ๋ชจ๋ธ(beomi/gemma-ko-2b)์ ๋ถ๋ฌ์ค๋ ์ค์ ๋๋ค. ์๊ฐ์ด ๋ค์ ๊ฑธ๋ฆด ์ ์์ต๋๋ค...")

model_id = "beomi/gemma-ko-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the weights in bfloat16 with automatic device placement disabled, then
# move the model to the CPU explicitly.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map=None,
)
model = model.to("cpu")

# Generation settings handed to the text-generation pipeline.
# NOTE(review): `temperature` is only honoured when sampling is enabled
# (`do_sample=True`). Nothing here enables it, so unless the model's own
# generation config does, decoding is greedy, the 0.7 is ignored and
# transformers emits a warning. Decide whether sampling was intended.
_generation_kwargs = {
    "max_new_tokens": 170,      # upper bound on newly generated tokens
    "temperature": 0.7,
    "repetition_penalty": 1.2,  # penalise repeated tokens
    "pad_token_id": tokenizer.eos_token_id,  # pad with the EOS id
}
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,  # -1 selects the CPU in the transformers pipeline API
    **_generation_kwargs,
)

# Wrap the pipeline so it can be composed into LangChain chains.
llm = HuggingFacePipeline(pipeline=pipe)
# --- 3. Prompt template ---
# Instructs the model to act as a dream-interpretation expert and answer the
# user's question in one sentence, based on the retrieved entries in {context}.
# The final "###" heading is also searched for by the answer-extraction code
# further down in this file, so the two literals must stay identical.
# NOTE(review): the non-ASCII text here appears to be Korean UTF-8 mis-decoded
# as Thai characters; restore it from the original source (and update the
# matching answer-heading literal at the same time).
template = """
๋น์ ์ ์ฌ์ฉ์์ ๊ฟ์ ๋ช ํํ๊ฒ ํด์ํด์ฃผ๋ ๊ฟ ํด๋ชฝ ์ ๋ฌธ๊ฐ์ ๋๋ค. ์๋ '๊ฒ์๋ ๊ฟ ํด๋ชฝ ์ ๋ณด'๋ฅผ ๋ฐํ์ผ๋ก, ์ฌ์ฉ์์ ์ง๋ฌธ์ ๋ํด ๊ฐ์ฅ ๊ด๋ จ์ฑ์ด ๋์ ๋ด์ฉ์ ํ ๋ฌธ์ฅ์ผ๋ก ์์ฝํด์ ๋ต๋ณํด์ฃผ์ธ์
### ๊ฒ์๋ ๊ฟ ํด๋ชฝ ์ ๋ณด:
{context}
### ์ฌ์ฉ์์ ์ง๋ฌธ:
{question}
### ์ ๋ฌธ๊ฐ์ ๋ต๋ณ:
"""
prompt = PromptTemplate.from_template(template)


# --- 4. RAG chain ---
def _format_docs(docs):
    """Render retrieved documents as bullet lines for the prompt's {context}.

    Produces one "- <page_content>: <meaning>" line per document, the same
    layout interpret_dream builds. Documents without a "meaning" metadata
    entry are listed by their content alone.

    Args:
        docs: Iterable of LangChain Documents returned by the retriever.

    Returns:
        A newline-joined string (empty when ``docs`` is empty).
    """
    lines = []
    for doc in docs:
        if "meaning" in doc.metadata:
            lines.append(f"- {doc.page_content}: {doc.metadata['meaning']}")
        else:
            lines.append(f"- {doc.page_content}")
    return "\n".join(lines)


# The retriever yields a list of Document objects; without _format_docs the
# prompt would receive that list's Python repr instead of readable text.
# Unlike interpret_dream, this chain applies no similarity-score threshold.
# NOTE(review): rag_chain is not referenced elsewhere in the visible code.
rag_chain = (
    {"context": retriever | _format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
# --- Response clean-up helper (carried over from the original version) ---
def clean_response(response):
    """Return the text before the first "###" marker, trimmed of whitespace.

    Everything from the first "###" onward is treated as unwanted trailing
    output and discarded. If there is no "###", the whole string is kept.

    Args:
        response: Model output text.

    Returns:
        The stripped leading segment of ``response``.
    """
    kept, _marker, _discarded = response.partition("###")
    return kept.strip()
# --- 5. Dream-interpretation function (called from the Gradio UI) ---
def interpret_dream(query, *, k=2, score_threshold=1.2):
    """Answer a dream-interpretation query with retrieval-augmented generation.

    Retrieves the ``k`` nearest entries from ``vectorstore`` and keeps only
    those whose score is below ``score_threshold``. It fills ``prompt`` with
    them and runs ``prompt | llm | StrOutputParser()``. The text after the
    template's answer heading is extracted and trimmed with
    ``clean_response``.

    Args:
        query: The user's dream description (the Gradio textbox value).
        k: How many candidate entries to retrieve. Keyword-only; defaults to
            2, the previously hard-coded value.
        score_threshold: Entries scoring at or above this are discarded, so
            lower scores count as closer matches (presumably a FAISS distance;
            confirm against the index's metric). Keyword-only; defaults to 1.2,
            the previously hard-coded value.

    Returns:
        The cleaned model answer, or a fixed apology message when no retrieved
        entry passes the threshold.
    """
    docs_with_scores = vectorstore.similarity_search_with_score(query, k=k)
    filtered_docs = [doc for doc, score in docs_with_scores if score < score_threshold]

    # Guard clause: nothing relevant enough was found.
    # NOTE(review): the non-ASCII literals in this function appear to be Korean
    # UTF-8 mis-decoded as Thai characters; restore them from the original
    # source together with the prompt template.
    if not filtered_docs:
        return "์ฃ์กํฉ๋๋ค. ํ์ฌ ๋ฐ์ดํฐ๋ฒ ์ด์ค์ ํด๋น ๊ฟ์ ๋ํ ์ ๋ณด๊ฐ ๋ถ์กฑํ์ฌ ํด๋ชฝ์ ์ ๊ณตํ๊ธฐ ์ด๋ ต์ต๋๋ค."

    # One "- <entry>: <meaning>" line per document. Entries without a "meaning"
    # metadata field are listed by content alone, instead of raising KeyError
    # and failing the whole request.
    context_lines = []
    for doc in filtered_docs:
        if "meaning" in doc.metadata:
            context_lines.append(f"- {doc.page_content}: {doc.metadata['meaning']}")
        else:
            context_lines.append(f"- {doc.page_content}")

    local_rag_chain = prompt | llm | StrOutputParser()
    response = local_rag_chain.invoke({
        "context": "\n".join(context_lines),
        "question": query,
    })

    # The generated text may echo the whole prompt; keep only what follows the
    # answer heading. This literal must match the last "###" heading in
    # `template`. If the heading is absent, fall back to the full response.
    answer_marker = "### ์ ๋ฌธ๊ฐ์ ๋ต๋ณ:"
    _, found, after_marker = response.partition(answer_marker)
    final_answer = after_marker.strip() if found else response.strip()

    # Drop any further "###" sections the model kept generating.
    return clean_response(final_answer)
# --- 6. Gradio user interface ---
# NOTE(review): the UI strings below appear to be Korean text mis-decoded as
# Thai characters; restore them from the original UTF-8 source.
with gr.Blocks() as demo:
    gr.Markdown("# โจ ๊ฟ ํด๋ชฝ ์๋น์ค")
    gr.Markdown("๋น์ ์ ๊ฟ ๋ด์ฉ์ ์ ๋ ฅํ์ธ์. AI๊ฐ ๋ช ํํ๊ฒ ํด์ํด ๋๋ฆฝ๋๋ค!")

    # Dream input and read-only result box, laid out in one row.
    with gr.Row():
        input_text = gr.Textbox(
            label="๋น์ ์ ๊ฟ์ ๋ฌด์์ด์๋์?",
            placeholder="์: ํ๋ ๋ถ์ด ๋๋ ๊ฟ",
        )
        output_text = gr.Textbox(
            label="๊ฟ ํด๋ชฝ ๊ฒฐ๊ณผ",
            interactive=False,
        )

    submit_button = gr.Button("ํด๋ชฝํ๊ธฐ")
    # Clicking the button runs interpret_dream on the textbox value and shows
    # the returned string in the result box.
    submit_button.click(fn=interpret_dream, inputs=input_text, outputs=output_text)

# Launch the app only when executed as a script.
if __name__ == '__main__':
    # share=True additionally requests a public share link (optional).
    # NOTE(review): Gradio does not create share links when running on
    # Hugging Face Spaces; this flag matters only for local runs.
    demo.launch(share=True)