# 패키지 불러오기
import re
import streamlit as st
from transformers import pipeline
# 글로벌 변수 선언
## 본 코드에서 사용된 신문기사 요약용 모델(https://huggingface.co/gangyeolkim/kobart-korean-summarizer-v2)은 테스트용 입니다.
pipe = pipeline("summarization", model="gangyeolkim/kobart-korean-summarizer-v2")
## assitant 아바타 이미지 경로
assistant_icon_path = "https://huggingface.co/spaces/randmimc/aitom/resolve/main/chat_icon/assistant.ico"
## 세션 상태에서 approval_state 및 messages 초기화
if "approval_state" not in st.session_state:
st.session_state.approval_state = "require"
if "messages" not in st.session_state:
st.session_state.messages = []
st.markdown("""
""", unsafe_allow_html=True)
def user_message_style(question):
return f"""
"""
def assistant_message_style(assistant_icon_path, answer):
return f"""
{answer}
"""
for message in st.session_state.messages:
if message["role"] == "user":
st.markdown(user_message_style(message["content"]), unsafe_allow_html=True)
else:
st.markdown(assistant_message_style(assistant_icon_path, message["content"]), unsafe_allow_html=True)
if not st.session_state.messages:
greeting = '신문기사 요약 모델 입니다. 사용을 위해서는 "yes"를 입력해 주세요'
st.session_state.messages.append({"role": "assistant", "content": greeting})
st.markdown(assistant_message_style(assistant_icon_path, greeting), unsafe_allow_html=True)
if prompt := st.chat_input():
# 사용자 질문 세션에 저장
st.session_state.messages.append({"role": "user", "content": prompt})
# 사용자 질문 화면에 표시
st.markdown(user_message_style(prompt), unsafe_allow_html=True)
# 인증
if st.session_state.approval_state == "require":
if prompt.lower() == "yes":
st.session_state.approval_state = "approved"
response = "이제 사용이 가능합니다. 단, 모델 크기가 작은 모델이여서 성능이 기대에 못미칠 수 있습니다."
else:
response = "인증에 실패하였습니다. 다시 yes를 입력해 주세요"
# 응답 생성
else:
result = re.sub(r'\s+', ' ', prompt)
summarized = pipe(result)
response = summarized[0]["summary_text"]
# 어시스턴스의 메세지를 세션에 저장
st.session_state.messages.append({"role": "assistant", "content": response})
# 어시스턴스의 메세지를 화면에 표시
st.markdown(assistant_message_style(assistant_icon_path, response), unsafe_allow_html=True)