|
|
import os, io, base64, json |
|
|
from datetime import datetime |
|
|
import streamlit as st |
|
|
from typing import List, Dict, Any |
|
|
|
|
|
from db import init_db, get_user_by_username, create_conversation, list_conversations, rename_conversation, delete_conversation, add_message, get_messages, list_users, set_user_role, set_user_active, update_user_password |
|
|
from auth import hash_password, verify_password, ensure_admin |
|
|
from providers import OpenAIProvider, OllamaProvider, ProviderError |
|
|
|
|
|
st.set_page_config(page_title="ChatAI (Streamlit)", page_icon="💬", layout="wide") |
|
|
|
|
|
|
|
|
custom_css = """ |
|
|
<style> |
|
|
.stChatMessage .stMarkdown { |
|
|
border-radius: 16px; |
|
|
padding: 12px 14px; |
|
|
background: rgba(240, 242, 246, 0.6); |
|
|
} |
|
|
.stChatMessage[data-testid="stChatMessageUser"] .stMarkdown { |
|
|
background: rgba(147, 197, 253, 0.25); |
|
|
} |
|
|
.plus-fab { |
|
|
position: fixed; |
|
|
bottom: 24px; |
|
|
right: 24px; |
|
|
width: 52px; |
|
|
height: 52px; |
|
|
border-radius: 50%; |
|
|
background: #4F46E5; |
|
|
color: white; |
|
|
border: none; |
|
|
box-shadow: 0 8px 24px rgba(0,0,0,0.15); |
|
|
font-size: 28px; |
|
|
cursor: pointer; |
|
|
z-index: 9999; |
|
|
} |
|
|
.upload-card { |
|
|
position: fixed; |
|
|
bottom: 88px; |
|
|
right: 24px; |
|
|
width: 320px; |
|
|
background: white; |
|
|
border-radius: 16px; |
|
|
box-shadow: 0 10px 30px rgba(0,0,0,0.15); |
|
|
padding: 16px; |
|
|
z-index: 9999; |
|
|
border: 1px solid rgba(0,0,0,0.06); |
|
|
} |
|
|
</style> |
|
|
""" |
|
|
st.markdown(custom_css, unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
init_db() |
|
|
created_admin, admin_pwd = ensure_admin() |
|
|
|
|
|
|
|
|
if "user" not in st.session_state: |
|
|
st.session_state.user = None |
|
|
if "conversation_id" not in st.session_state: |
|
|
st.session_state.conversation_id = None |
|
|
if "show_uploader" not in st.session_state: |
|
|
st.session_state.show_uploader = False |
|
|
if "messages_cache" not in st.session_state: |
|
|
st.session_state.messages_cache = [] |
|
|
|
|
|
def do_login(username, password): |
|
|
user = get_user_by_username(username) |
|
|
if not user or not user["is_active"]: |
|
|
st.error("Sai tài khoản hoặc tài khoản đang bị khóa.") |
|
|
return False |
|
|
if verify_password(password, user["password_hash"]): |
|
|
st.session_state.user = {"id": user["id"], "username": user["username"], "role": user["role"]} |
|
|
return True |
|
|
st.error("Mật khẩu không đúng.") |
|
|
return False |
|
|
|
|
|
def logout(): |
|
|
st.session_state.user = None |
|
|
st.session_state.conversation_id = None |
|
|
st.session_state.messages_cache = [] |
|
|
|
|
|
|
|
|
with st.sidebar: |
|
|
st.header("⚙️ Cấu hình") |
|
|
|
|
|
if st.session_state.user: |
|
|
st.success(f"Đã đăng nhập: **{st.session_state.user['username']}** ({st.session_state.user['role']})") |
|
|
if st.button("Đăng xuất"): |
|
|
logout() |
|
|
st.rerun() |
|
|
else: |
|
|
st.subheader("Đăng nhập") |
|
|
with st.form("login_form", clear_on_submit=False): |
|
|
u = st.text_input("Tên đăng nhập") |
|
|
p = st.text_input("Mật khẩu", type="password") |
|
|
submitted = st.form_submit_button("Đăng nhập") |
|
|
if submitted: |
|
|
if do_login(u, p): |
|
|
st.rerun() |
|
|
|
|
|
st.divider() |
|
|
st.subheader("Nhà cung cấp AI") |
|
|
provider = st.selectbox("Provider", ["OpenAI", "Ollama"]) |
|
|
if provider == "OpenAI": |
|
|
openai_key = st.text_input("OpenAI API Key", type="password", value=os.environ.get("OPENAI_API_KEY", "")) |
|
|
model = st.text_input("Model", value="gpt-4o-mini") |
|
|
else: |
|
|
ollama_url = st.text_input("Ollama Endpoint", value="http://localhost:11434") |
|
|
model = st.text_input("Model", value="llama3.1:8b") |
|
|
temperature = st.slider("Temperature", 0.0, 1.0, 0.3, 0.05) |
|
|
|
|
|
st.divider() |
|
|
st.subheader("Tùy chọn") |
|
|
sys_prompt = st.text_area("System Prompt (tùy chọn)", value="You are a helpful assistant. Answer in Vietnamese if the user speaks Vietnamese.") |
|
|
if created_admin: |
|
|
st.info(f"Admin mặc định đã được tạo. Tài khoản: admin / Mật khẩu: {admin_pwd} — hãy đổi ngay!") |
|
|
|
|
|
|
|
|
def page_chat(): |
|
|
st.title("💬 ChatAI") |
|
|
|
|
|
if not st.session_state.user: |
|
|
st.info("Hãy đăng nhập để bắt đầu trò chuyện.") |
|
|
return |
|
|
|
|
|
|
|
|
left, right = st.columns([1, 3]) |
|
|
with left: |
|
|
st.subheader("🗂 Cuộc trò chuyện") |
|
|
if st.button("➕ Tạo cuộc trò chuyện mới"): |
|
|
st.session_state.conversation_id = create_conversation(st.session_state.user["id"], title="New Chat") |
|
|
st.session_state.messages_cache = [] |
|
|
st.rerun() |
|
|
convs = list_conversations(st.session_state.user["id"]) |
|
|
for c in convs: |
|
|
selected = st.button(f"🗨 {c['title']}", key=f"conv_{c['id']}") |
|
|
if selected: |
|
|
st.session_state.conversation_id = c["id"] |
|
|
st.session_state.messages_cache = get_messages(c["id"]) |
|
|
st.rerun() |
|
|
|
|
|
with right: |
|
|
if not st.session_state.conversation_id: |
|
|
st.info("Chưa có cuộc trò chuyện. Hãy tạo mới bên trái.") |
|
|
return |
|
|
|
|
|
|
|
|
cc1, cc2 = st.columns([3,1]) |
|
|
with cc1: |
|
|
new_title = st.text_input("Tên cuộc trò chuyện", value="") |
|
|
if st.button("Đổi tên"): |
|
|
if new_title.strip(): |
|
|
rename_conversation(st.session_state.conversation_id, new_title.strip()) |
|
|
st.success("Đã đổi tên.") |
|
|
else: |
|
|
st.warning("Tên không hợp lệ.") |
|
|
with cc2: |
|
|
if st.button("🗑 Xóa cuộc trò chuyện", type="secondary"): |
|
|
delete_conversation(st.session_state.conversation_id) |
|
|
st.session_state.conversation_id = None |
|
|
st.session_state.messages_cache = [] |
|
|
st.rerun() |
|
|
|
|
|
|
|
|
msgs = st.session_state.messages_cache or get_messages(st.session_state.conversation_id) |
|
|
for m in msgs: |
|
|
role = m["role"] |
|
|
content = m["content"] |
|
|
with st.chat_message("assistant" if role=="assistant" else "user"): |
|
|
st.markdown(content) |
|
|
try: |
|
|
atts = json.loads(m.get("attachments") or "[]") |
|
|
for a in atts: |
|
|
st.caption(f"📎 {a.get('name','file')} ({a.get('type','file')})") |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
user_msg = st.chat_input("Nhập tin nhắn...") |
|
|
|
|
|
|
|
|
st.markdown('<button class="plus-fab" onclick="window.parent.postMessage({type:\\'toggle_uploader\\'}, \\\"*\\\")">+</button>', unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
if st.button("Hiện/Tắt upload (fallback)"): |
|
|
st.session_state.show_uploader = not st.session_state.show_uploader |
|
|
|
|
|
if st.session_state.show_uploader: |
|
|
with st.container(): |
|
|
st.markdown('<div class="upload-card">', unsafe_allow_html=True) |
|
|
st.write("**Tải lên để đính kèm**") |
|
|
file_uploader = st.file_uploader("Tệp (txt, pdf)", type=["txt","pdf"], accept_multiple_files=True, key="file_up") |
|
|
img_uploader = st.file_uploader("Ảnh", type=["png","jpg","jpeg","webp"], accept_multiple_files=True, key="img_up") |
|
|
if st.button("Đóng"): |
|
|
st.session_state.show_uploader = False |
|
|
st.markdown('</div>', unsafe_allow_html=True) |
|
|
|
|
|
|
|
|
if user_msg or (st.session_state.get("file_up") or st.session_state.get("img_up")): |
|
|
attachments = [] |
|
|
|
|
|
def extract_text_from_file(f): |
|
|
name = f.name |
|
|
if name.lower().endswith(".txt"): |
|
|
return f.read().decode("utf-8", errors="ignore") |
|
|
if name.lower().endswith(".pdf"): |
|
|
try: |
|
|
import PyPDF2 |
|
|
reader = PyPDF2.PdfReader(io.BytesIO(f.read())) |
|
|
pages = [] |
|
|
for p in reader.pages: |
|
|
pages.append(p.extract_text() or "") |
|
|
return "\\n".join(pages) |
|
|
except Exception as e: |
|
|
return f"[Không thể trích xuất PDF: {e}]" |
|
|
return "" |
|
|
|
|
|
uploaded_files = st.session_state.get("file_up") or [] |
|
|
uploaded_imgs = st.session_state.get("img_up") or [] |
|
|
|
|
|
context_snippets = [] |
|
|
for f in uploaded_files: |
|
|
text = extract_text_from_file(f) |
|
|
attachments.append({"name": f.name, "type": "file", "size": f.size}) |
|
|
if text: |
|
|
context_snippets.append(f"### {f.name}\\n{text[:6000]}") |
|
|
|
|
|
image_refs = [] |
|
|
for img in uploaded_imgs: |
|
|
b64 = base64.b64encode(img.read()).decode("utf-8") |
|
|
mime = "image/png" if img.name.lower().endswith("png") else "image/jpeg" |
|
|
image_refs.append({"name": img.name, "type": "image", "b64": b64, "mime": mime}) |
|
|
attachments.append({"name": img.name, "type": "image", "size": img.size}) |
|
|
|
|
|
db_msgs = get_messages(st.session_state.conversation_id) |
|
|
chat_history = [{"role": m["role"], "content": m["content"]} for m in db_msgs] |
|
|
|
|
|
system_preamble = sys_prompt.strip() if sys_prompt else "" |
|
|
if context_snippets: |
|
|
system_preamble += "\\n\\n# File context (tóm tắt)\\n" + "\\n\\n".join(context_snippets) |
|
|
|
|
|
messages_for_provider: List[Dict[str, Any]] = [] |
|
|
if system_preamble: |
|
|
messages_for_provider.append({"role": "system", "content": system_preamble}) |
|
|
|
|
|
messages_for_provider.extend(chat_history[-12:]) |
|
|
if user_msg: |
|
|
messages_for_provider.append({"role": "user", "content": user_msg}) |
|
|
|
|
|
try: |
|
|
if provider == "OpenAI": |
|
|
p = OpenAIProvider(api_key=openai_key if openai_key else None) |
|
|
resp = p.generate(messages=messages_for_provider, model=model, temperature=temperature) |
|
|
else: |
|
|
p = OllamaProvider(base_url=ollama_url) |
|
|
resp = p.generate(messages=messages_for_provider, model=model, temperature=temperature) |
|
|
except ProviderError as e: |
|
|
resp = f"Lỗi nhà cung cấp: {e}" |
|
|
except Exception as e: |
|
|
resp = f"Lỗi không xác định: {e}" |
|
|
|
|
|
if user_msg: |
|
|
add_message(st.session_state.conversation_id, "user", user_msg, attachments=json.dumps(attachments)) |
|
|
st.session_state.messages_cache.append({"role":"user","content":user_msg,"attachments":json.dumps(attachments)}) |
|
|
add_message(st.session_state.conversation_id, "assistant", resp) |
|
|
st.session_state.messages_cache.append({"role":"assistant","content":resp,"attachments":"[]"}) |
|
|
st.session_state.show_uploader = False |
|
|
st.rerun() |
|
|
|
|
|
|
|
|
st.divider() |
|
|
if st.button("⬇️ Xuất cuộc trò chuyện (Markdown)"): |
|
|
msgs = get_messages(st.session_state.conversation_id) |
|
|
md = ["# Lịch sử trò chuyện"] |
|
|
for m in msgs: |
|
|
who = "👤 User" if m["role"]=="user" else "🤖 Assistant" |
|
|
md.append(f"**{who}**\\n\\n{m['content']}\\n") |
|
|
b = "\\n\\n---\\n\\n".join(md).encode("utf-8") |
|
|
st.download_button("Tải về .md", data=b, file_name="chat_history.md", mime="text/markdown") |
|
|
|
|
|
def page_admin(): |
|
|
if not st.session_state.user or st.session_state.user["role"] != "admin": |
|
|
st.warning("Chỉ Admin mới truy cập được trang này.") |
|
|
return |
|
|
|
|
|
st.title("🛠 Admin Panel") |
|
|
st.subheader("Quản lý người dùng") |
|
|
|
|
|
users = list_users() |
|
|
for u in users: |
|
|
cols = st.columns([2,1,1,1,2]) |
|
|
cols[0].write(f"**{u['username']}**") |
|
|
cols[1].write(u["role"]) |
|
|
cols[2].write("✅" if u["is_active"] else "🚫") |
|
|
if cols[3].button("Đổi vai trò", key=f"role_{u['id']}"): |
|
|
new_role = "admin" if u["role"]=="user" else "user" |
|
|
set_user_role(u["id"], new_role) |
|
|
st.rerun() |
|
|
with cols[4]: |
|
|
c1, c2, c3 = st.columns(3) |
|
|
if c1.button("Khóa/Mở", key=f"toggle_{u['id']}"): |
|
|
set_user_active(u["id"], not u["is_active"]) |
|
|
st.rerun() |
|
|
if c2.button("Đặt lại MK", key=f"reset_{u['id']}"): |
|
|
newpw = st.text_input(f"Mật khẩu mới cho {u['username']}", key=f"pw_{u['id']}") |
|
|
if newpw: |
|
|
update_user_password(u["id"], hash_password(newpw)) |
|
|
st.success("Đã cập nhật mật khẩu.") |
|
|
st.rerun() |
|
|
if c3.button("Xóa", key=f"del_{u['id']}"): |
|
|
from db import delete_user as delu |
|
|
if u["username"]=="admin": |
|
|
st.error("Không được xóa tài khoản admin gốc.") |
|
|
else: |
|
|
delu(u["id"]) |
|
|
st.rerun() |
|
|
|
|
|
st.divider() |
|
|
st.subheader("Thêm người dùng mới") |
|
|
with st.form("new_user_form"): |
|
|
nu = st.text_input("Tên đăng nhập") |
|
|
np = st.text_input("Mật khẩu", type="password") |
|
|
role = st.selectbox("Vai trò", ["user", "admin"]) |
|
|
submitted = st.form_submit_button("Tạo") |
|
|
if submitted: |
|
|
if not nu or not np: |
|
|
st.error("Thiếu thông tin.") |
|
|
elif get_user_by_username(nu): |
|
|
st.error("Tên đăng nhập đã tồn tại.") |
|
|
else: |
|
|
from db import create_user |
|
|
create_user(nu, hash_password(np), role=role, is_active=True) |
|
|
st.success("Đã tạo người dùng.") |
|
|
st.rerun() |
|
|
|
|
|
|
|
|
tab = st.tabs(["💬 Chat", "🛠 Admin"]) |
|
|
with tab[0]: |
|
|
page_chat() |
|
|
with tab[1]: |
|
|
page_admin() |
|
|
|
|
|
|
|
|
st.markdown(\""" |
|
|
<script> |
|
|
window.addEventListener('message', (event) => { |
|
|
if (event.data && event.data.type === 'toggle_uploader') { |
|
|
const btns = window.parent.document.querySelectorAll('button'); |
|
|
if (btns && btns.length > 0) { btns[btns.length-1].click(); } |
|
|
} |
|
|
}, false); |
|
|
</script> |
|
|
\""", unsafe_allow_html=True) |
|
|
|