import base64
from pathlib import Path
from typing import Optional, Sequence

import pandas as pd
import requests
import streamlit as st
import torch
from requests.auth import HTTPBasicAuth

# Longest input (in characters) the translate box accepts; also passed to the
# tokenizer as max_length.
MAX_INPUT_CHARS = 168
# Files larger than this (bytes) are not loaded; the user is linked out instead.
MAX_FILE_BYTES = 3 * 1024 * 1024
# Seconds to wait for the GitHub API before giving up.
GITHUB_TIMEOUT = 30

st.set_page_config(layout="wide")


@st.cache(allow_output_mutation=True)
def load_model():
    """Load the classical-to-modern Chinese translation tokenizer and model.

    Wrapped in ``st.cache`` so the model is not reloaded on every rerun.

    Returns:
        ``(tokenizer, model)`` from the Hugging Face hub.
    """
    from transformers import (
        EncoderDecoderModel,
        AutoTokenizer,
    )
    PRETRAINED = "raynardj/wenyanwen-ancient-translate-to-modern"
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
    model = EncoderDecoderModel.from_pretrained(PRETRAINED)
    return tokenizer, model


tokenizer, model = load_model()


def inference(text):
    """Translate one classical-Chinese sentence into modern Chinese.

    Args:
        text: the sentence to translate.

    Returns:
        The first decoded beam-search output, with special tokens and all
        spaces removed.
    """
    print(f"from: {text}")
    # NOTE(review): max_length counts tokens, including any special tokens the
    # tokenizer adds, so inputs right at MAX_INPUT_CHARS characters may lose
    # their tail to truncation — confirm against the tokenizer.
    tk_kwargs = dict(
        truncation=True,
        max_length=MAX_INPUT_CHARS,
        padding="max_length",
        return_tensors="pt",
    )
    inputs = tokenizer([text], **tk_kwargs)
    with torch.no_grad():
        generated = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            num_beams=3,
            max_length=256,
            # presumably the [CLS] id of a BERT-style vocabulary — confirm
            bos_token_id=101,
            eos_token_id=tokenizer.sep_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Spaces are stripped from the decoded text (the tokenizer presumably
    # joins tokens with spaces).
    new = tokenizer.batch_decode(generated, skip_special_tokens=True)[0].replace(" ", "")
    print(f"to: {new}")
    return new


@st.cache
def get_file_df():
    """Read the corpus index ``meta.csv`` into a DataFrame."""
    file_df = pd.read_csv("meta.csv")
    return file_df


file_df = get_file_df()

st.sidebar.title("【隨無涯】")
st.sidebar.markdown("""
* 朕自庖[🤗 模型](https://huggingface.co/raynardj/wenyanwen-ancient-translate-to-modern), [⭐️ 訓習處](https://github.com/raynardj/yuan)
* 📚 充棟汗牛,取自[殆知閣](http://www.daizhige.org/),[github api](https://github.com/garychowcmu/daizhigev20)
""")

c2 = st.container()
c2.write("The entirety of ancient Chinese literature, with a modern translator at your side.")
st.markdown("""---""")
c = st.container()

USER_ID = st.secrets["USER_ID"]
SECRET = st.secrets["SECRET"]


@st.cache
def get_maps():
    """Build ``filepath -> obj_hash`` and ``filepath -> fsize`` lookups from meta.csv."""
    file_obj_hash_map = dict(file_df[["filepath", "obj_hash"]].values)
    file_size_map = dict(file_df[["filepath", "fsize"]].values)
    return file_obj_hash_map, file_size_map


file_obj_hash_map, file_size_map = get_maps()


def show_file_size(size: int):
    """Format a byte count as whole B / KB / MB (rounded down)."""
    if size < 1024:
        return f"{size} B"
    if size < 1024 * 1024:
        return f"{size // 1024} KB"
    # Integer division, consistent with the KB branch.
    return f"{size // (1024 * 1024)} MB"


@st.cache(max_entries=100, allow_output_mutation=True)
def fetch_file(path):
    """Return the text of ``path`` (a ``filepath`` value from meta.csv).

    A local copy under ``data/`` is preferred. Otherwise the blob is fetched
    from the GitHub git-blobs API, using the ``obj_hash`` recorded for it.

    Raises:
        KeyError: ``path`` has no local copy and is not listed in meta.csv.
        requests.HTTPError: GitHub did not answer with 200 OK.
        requests.Timeout: GitHub did not answer within GITHUB_TIMEOUT seconds.
    """
    local_path = Path("data") / path
    if local_path.exists():
        # Explicit encoding: the platform default depends on the host locale.
        return local_path.read_text(encoding="utf-8")

    # read from github api
    obj_hash = file_obj_hash_map[path]
    auth = HTTPBasicAuth(USER_ID, SECRET)
    url = f"https://api.github.com/repos/garychowcmu/daizhigev20/git/blobs/{obj_hash}"
    print(f"requesting {url}")
    response = requests.get(url, auth=auth, timeout=GITHUB_TIMEOUT)
    response.raise_for_status()
    if response.status_code != 200:
        # raise_for_status() only covers 4xx/5xx; any other non-200 status
        # carries no blob content, so fail loudly instead of returning None.
        raise requests.HTTPError(
            f"unexpected status {response.status_code} for {url}",
            response=response,
        )
    data = response.json()
    return base64.b64decode(data["content"]).decode("utf-8")


@st.cache(allow_output_mutation=True, max_entries=100)
def fetch_from_df(sub_paths: Sequence[str] = ()) -> Optional[list]:
    """List the distinct entries one level below the path prefix ``sub_paths``.

    Path components are stored in columns ``col_0``, ``col_1``, ... of
    meta.csv. Rows are filtered so that ``col_i == sub_paths[i]``, and the
    unique non-null values of ``col_{len(sub_paths)}`` are returned.

    Returns:
        The child entries, or ``None`` if no row matches the prefix.
    """
    sub_df = file_df
    for idx, step in enumerate(sub_paths):
        # Boolean mask rather than DataFrame.query: a name containing a quote
        # would break (or alter) an interpolated query expression.
        sub_df = sub_df[sub_df[f"col_{idx}"] == step]
        if len(sub_df) == 0:
            return None
    return list(sub_df[f"col_{len(sub_paths)}"].dropna().unique())


def show_filepath(filepath: str):
    """Fetch ``filepath`` and render its text in the main content container."""
    text = fetch_file(filepath)
    # NOTE(review): remote file contents are rendered with unsafe_allow_html;
    # any HTML inside the texts will be interpreted — confirm this is intended.
    c.markdown(f"\n{text}\n", unsafe_allow_html=True)


mode = st.sidebar.selectbox(
    label="何以尋跡 How to search",
    options=["以類尋書 category", "書名求書 search"],
)

if mode == "以類尋書 category":
    if "pathway" not in st.session_state:
        st.session_state.pathway = []

    path_text = st.sidebar.text("/".join(st.session_state.pathway))

    def reset_path():
        """Clear the browsing path and blank the sidebar path display."""
        st.session_state.pathway = []
        path_text.text("/".join(st.session_state.pathway))

    if st.sidebar.button("還至初錄(back to root)"):
        reset_path()

    def display_tree():
        """Render one directory level, recursing until a ``.txt`` file is shown."""
        sublist = fetch_from_df(st.session_state.pathway)
        if not sublist:
            # Nothing below this prefix: start over from the root instead of
            # crashing on an empty/None option list.
            reset_path()
            return None
        dropdown = st.sidebar.selectbox("【擇書 choose】", options=sublist)
        with st.spinner("書非借不能讀也..."):
            st.session_state.pathway.append(dropdown)
            if str(dropdown).endswith(".txt"):
                filepath = "/".join(st.session_state.pathway)
                file_size = file_size_map[filepath]
                with st.spinner(f"Load 載文:{filepath},({show_file_size(file_size)})"):
                    # Too large to load here: link to the original site.
                    if file_size > MAX_FILE_BYTES:
                        print(f"skip {filepath}")
                        urlpath = filepath.replace(".txt", ".html")
                        dzg = f"http://www.daizhige.org/{urlpath}"
                        st.markdown(f"File too big 其文碩而難載,不能為之,[往 殆知閣]({dzg}), 或擇他書")
                        reset_path()
                        return None
                    path_text.text(filepath)
                    print(f"read {filepath}")
                    show_filepath(filepath)
                    reset_path()
            else:
                # A directory: show the path so far and render the next level.
                path_text.text("/".join(st.session_state.pathway))
                display_tree()

    display_tree()
else:
    def search_kw():
        """Show up to 15 files whose path contains the keyword; load the selected one."""
        keyword = st.session_state.kw
        # Literal substring match: book titles may contain regex metacharacters.
        matches = file_df.filepath.str.contains(keyword, regex=False, na=False)
        result = file_df[matches].reset_index(drop=True)
        if len(result) == 0:
            st.sidebar.write(f"尋之不得:{st.session_state.kw}")
        else:
            filepath = st.sidebar.selectbox("選一書名", options=list(result.head(15).filepath))
            show_filepath(filepath)

    def loading_with_search():
        """Read the search keyword into session state and run the search."""
        kw = st.sidebar.text_input("書名求書 Search", value="楞伽经")
        st.session_state.kw = kw
        search_kw()

    loading_with_search()


def translate_text():
    """Render the translate button and, when clicked, translate ``cc``.

    Reads the module-level ``cc`` (the text-area value), which is assigned
    just before this function is called.
    """
    if not c2.button("【曉文達義 Translate】"):
        return
    if not cc:
        c2.write("【入難曉之文字 Please input sentence for translating】")
    elif len(cc) > MAX_INPUT_CHARS:
        c2.write("句甚長 不得過百又六十八字 Sentence too long, should be at most 168 characters")
    else:
        c2.markdown(f"""```{inference(cc)}```""")


cc = c2.text_area("【入難曉之文字 Input sentence】", height=150)
translate_text()