import hydra
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, OmegaConf
import streamlit as st
from PIL import Image
import os
import sys

sys.path.append(os.path.dirname(__file__))

import torch
from download_models import download_model


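# Build the WuleRAG pipeline once and cache it across Streamlit reruns.
# The backing LLM is either the plain InternLM wrapper or an LMDeploy-based
# engine, selected by `used_lmdeploy`.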
@st.cache_resource
def load_simple_rag(config, used_lmdeploy=False):
    data_source_dir = config["data_source_dir"]
    db_persist_directory = config["db_persist_directory"]
    llm_model = config["llm_model"]
    embeddings_model = config["embeddings_model"]
    reranker_model = config["reranker_model"]
    llm_system_prompt = config["llm_system_prompt"]
    rag_prompt_template = config["rag_prompt_template"]

    from rag.simple_rag import WuleRAG

    if not used_lmdeploy:
        # Plain InternLM backend
        from rag.simple_rag import InternLM
        base_model = InternLM(model_path=llm_model, llm_system_prompt=llm_system_prompt)
    else:
        # LMDeploy backend; cache_max_entry_count sets the fraction of GPU memory
        # reserved for the k/v cache (default 0.2)
        from deploy.lmdeploy_model import LmdeployLM
        cache_max_entry_count = config.get("cache_max_entry_count", 0.2)
        base_model = LmdeployLM(
            model_path=llm_model,
            llm_system_prompt=llm_system_prompt,
            cache_max_entry_count=cache_max_entry_count,
        )

    wulewule_rag = WuleRAG(data_source_dir, db_persist_directory, base_model, embeddings_model, reranker_model, rag_prompt_template)
    return wulewule_rag


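# Reset any previously initialized Hydra instance so @hydra.main can
# re-initialize cleanly when Streamlit re-executes this script.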
GlobalHydra.instance().clear()


@hydra.main(version_base=None, config_path="./configs", config_name="model_cfg")
def main(cfg: DictConfig):
    config_dict = OmegaConf.to_container(cfg, resolve=True)

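    # Download the LLM weights on first run if they are not present locally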
    if not os.path.exists(config_dict["llm_model"]):
        download_model(llm_model_path=config_dict["llm_model"])

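    # Select the model backend; the config is expected to enable at least one
    # of `use_rag` / `use_lmdepoly`, otherwise `wulewule_model` stays undefined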
    if cfg.use_rag:
        wulewule_model = load_simple_rag(config_dict, used_lmdeploy=cfg.use_lmdepoly)
    elif cfg.use_lmdepoly:
        from deploy.lmdeploy_model import load_turbomind_model
        wulewule_model = load_turbomind_model(config_dict["llm_model"], config_dict["llm_system_prompt"], config_dict["cache_max_entry_count"])

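    # Initialize the chat history stored in Streamlit's session state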
if "messages" not in st.session_state: |
|
st.session_state["messages"] = [] |
|
|
|
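    # Sidebar: project logo and related links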
    with st.sidebar:
        st.markdown("## 悟了悟了💡")
        logo_path = "assets/sd_wulewule.webp"
        if os.path.exists(logo_path):
            image = Image.open(logo_path)
            st.image(image, caption='wulewule')
        # Streamlit "magic" renders these bare string literals as markdown links
        "[InternLM](https://github.com/InternLM)"
        "[悟了悟了](https://github.com/xzyun2011/wulewule.git)"

    st.title("悟了悟了:黑神话悟空AI助手🐒")

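    # Replay the conversation so far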
    for msg in st.session_state.messages:
        st.chat_message("user").write(msg["user"])
        st.chat_message("assistant").write(msg["assistant"])

    if prompt := st.chat_input("请输入你的问题,换行使用Shift+Enter。"):
        st.chat_message("user").write(prompt)

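        # Stream tokens into a placeholder as they are generated, showing a
        # '▌' cursor until the answer is complete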
        if cfg.stream_response:
            full_answer = ""
            with st.chat_message('assistant'):
                message_placeholder = st.empty()
                if cfg.use_rag:
                    for cur_response in wulewule_model.query_stream(prompt):
                        full_answer += cur_response
                        message_placeholder.markdown(full_answer + '▌')
                elif cfg.use_lmdepoly:
                    messages = [{'role': 'user', 'content': f'{prompt}'}]
                    for response in wulewule_model.stream_infer(messages):
                        full_answer += response.text
                        message_placeholder.markdown(full_answer + '▌')
                message_placeholder.markdown(full_answer)
        else:
            if cfg.use_lmdepoly:
                messages = [{'role': 'user', 'content': f'{prompt}'}]
                full_answer = wulewule_model(messages).text
            elif cfg.use_rag:
                full_answer = wulewule_model.query(prompt)
            st.chat_message("assistant").write(full_answer)

        st.session_state.messages.append({"user": prompt, "assistant": full_answer})
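        # Release cached GPU memory between requests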
        torch.cuda.empty_cache()


if __name__ == "__main__":
    main()