"""Streamlit app: browse classical Chinese books and translate classical text
into modern Chinese with a pretrained encoder-decoder model.

The book tree is read from ``meta.csv``; a book's text is read from a local
``data/`` directory when present, otherwise fetched from the GitHub blob API.
"""
import base64
from pathlib import Path

import pandas as pd
import requests
import streamlit as st
import torch
from requests.auth import HTTPBasicAuth

# Upper bound for translation input: the UI rejects longer inputs (characters)
# and the tokenizer truncates/pads to this many tokens.
MAX_INPUT_LENGTH = 168
# Books larger than this (bytes) are not loaded; the user gets a link instead.
MAX_FILE_SIZE = 3 * 1024 * 1024
# Seconds to wait for the GitHub API so a stalled request cannot hang the app.
REQUEST_TIMEOUT = 30

st.set_page_config(layout="wide")

st.title("【随无涯】")


# NOTE(review): st.cache is deprecated in newer Streamlit releases;
# st.cache_resource / st.cache_data are the replacements if the pinned
# Streamlit version supports them.
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the tokenizer and encoder-decoder translation model.

    Cached by Streamlit so reruns of the script reuse the loaded objects;
    ``allow_output_mutation=True`` skips hashing the returned model.

    Returns:
        (tokenizer, model) tuple.
    """
    from transformers import (
        EncoderDecoderModel,
        AutoTokenizer,
    )
    PRETRAINED = "raynardj/wenyanwen-ancient-translate-to-modern"
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
    model = EncoderDecoderModel.from_pretrained(PRETRAINED)
    return tokenizer, model


tokenizer, model = load_model()


def inference(text):
    """Translate one passage of classical Chinese into modern Chinese.

    The input is truncated/padded to MAX_INPUT_LENGTH tokens, decoded with
    3-beam search (at most 256 tokens), and all spaces are stripped from the
    decoded string.

    Args:
        text: the passage to translate.

    Returns:
        The translated text as a single string.
    """
    tk_kwargs = dict(
        truncation=True,
        max_length=MAX_INPUT_LENGTH,
        padding="max_length",
        return_tensors="pt",
    )
    inputs = tokenizer([text], **tk_kwargs)
    with torch.no_grad():
        generated = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            num_beams=3,
            max_length=256,
            # NOTE(review): 101 is presumably the [CLS] id of the model's
            # BERT-style vocabulary — confirm against the tokenizer.
            bos_token_id=101,
            eos_token_id=tokenizer.sep_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Spaces are removed from the decoded output (presumably the separators
    # the decoder puts between Chinese characters).
    return tokenizer.batch_decode(
        generated, skip_special_tokens=True)[0].replace(" ", "")


@st.cache
def get_file_df():
    """Load the book index ``meta.csv`` (cached across reruns)."""
    file_df = pd.read_csv("meta.csv")
    return file_df


file_df = get_file_df()

col1, col2 = st.columns([.3, 1])
col1.markdown("""
* 朕亲自下厨的[🤗 翻译模型](https://github.com/raynardj/wenyanwen-ancient-translate-to-modern), [⭐️ 训练笔记](https://github.com/raynardj/yuan)
* 📚 书籍来自 [殆知阁](http://www.daizhige.org/),只为了便于展示翻译,喜欢请访问网站,书籍[github文件链接](https://github.com/garychowcmu/daizhigev20)
""")

USER_ID = st.secrets["USER_ID"]
SECRET = st.secrets["SECRET"]


@st.cache
def get_maps():
    """Build ``filepath -> obj_hash`` and ``filepath -> fsize`` lookups from meta.csv."""
    file_obj_hash_map = dict(file_df[["filepath", "obj_hash"]].values)
    file_size_map = dict(file_df[["filepath", "fsize"]].values)
    return file_obj_hash_map, file_size_map


file_obj_hash_map, file_size_map = get_maps()


def show_file_size(size: int):
    """Format a byte count as a whole number of B, KB or MB (rounded down)."""
    # int() so numpy/float sizes from the CSV never render as "3.0 KB".
    size = int(size)
    if size < 1024:
        return f"{size} B"
    if size < 1024 * 1024:
        return f"{size // 1024} KB"
    return f"{size // (1024 * 1024)} MB"


def fetch_file(path):
    """Return the text of the book at ``path`` (a ``filepath`` from meta.csv).

    A local copy under ``data/`` is preferred; otherwise the blob is fetched
    from the GitHub API using the object hash recorded in meta.csv.

    Raises:
        KeyError: if there is no local copy and ``path`` is not in meta.csv.
        requests.HTTPError: if GitHub does not answer with HTTP 200.
    """
    local_path = Path("data") / path
    if local_path.exists():
        # Explicit UTF-8: the platform default encoding differs across OSes,
        # and the GitHub branch below decodes as UTF-8 too.
        with open(local_path, "r", encoding="utf-8") as f:
            return f.read()

    obj_hash = file_obj_hash_map[path]
    url = f"https://api.github.com/repos/garychowcmu/daizhigev20/git/blobs/{obj_hash}"
    r = requests.get(
        url, auth=HTTPBasicAuth(USER_ID, SECRET), timeout=REQUEST_TIMEOUT)
    r.raise_for_status()
    if r.status_code != 200:
        # raise_for_status() ignores 1xx/3xx; never fall through returning None.
        raise requests.HTTPError(
            f"unexpected status {r.status_code} for {url}", response=r)
    data = r.json()
    return base64.b64decode(data["content"]).decode("utf-8")


def fetch_from_df(sub_paths=""):
    """List the entries one level below the folder given by ``sub_paths``.

    Args:
        sub_paths: sequence of path components (e.g. ``st.session_state.pathway``);
            component ``i`` is matched against column ``col_{i}`` of meta.csv.

    Returns:
        The distinct non-null values of column ``col_{len(sub_paths)}`` among
        matching rows, or None if no row matches or that column does not exist.
    """
    sub_df = file_df
    for idx, step in enumerate(sub_paths):
        # Boolean mask rather than DataFrame.query: a name containing a quote
        # would break (or be evaluated inside) a string-built query expression.
        sub_df = sub_df[sub_df[f"col_{idx}"] == step]
        if sub_df.empty:
            return None
    next_col = f"col_{len(sub_paths)}"
    if next_col not in sub_df.columns:
        return None
    return list(sub_df[next_col].dropna().unique())


if "pathway" not in st.session_state:
    st.session_state.pathway = []

path_text = col1.text("/".join(st.session_state.pathway))


def reset_path():
    """Clear the selected path and the path label in the left column."""
    st.session_state.pathway = []
    # Joined path (an empty string), consistent with every other update of
    # the label, instead of the list repr "[]".
    path_text.text("/".join(st.session_state.pathway))


if col1.button("回到根目录"):
    reset_path()


def _show_book(filepath):
    """Show the book at ``filepath`` in the right column, then reset the path."""
    file_size = file_size_map[filepath]
    col2.write(
        f"loading file:{filepath},({show_file_size(file_size)})")
    # If the file is too large, do not load it; link to daizhige.org instead.
    if file_size > MAX_FILE_SIZE:
        urlpath = filepath.replace(".txt", ".html")
        dzg = f"http://www.daizhige.org/{urlpath}"
        st.markdown(f"文件太大,[前往殆知阁页面]({dzg}), 或挑挑其他的书吧")
        reset_path()
        return
    path_text.text(filepath)
    text = fetch_file(filepath)
    # Fenced code block (scrollable). The fences must sit on their own lines:
    # with "```{text}" the book's first line would be parsed as the fence's
    # info string and vanish. The fence is also made longer than any backtick
    # run in the text so the book cannot close the block early.
    fence = "```"
    while fence in text:
        fence += "`"
    col2.markdown(f"{fence}\n{text}\n{fence}")
    reset_path()


def display_tree(sub_list):
    """Render a picker for the current level of the book tree.

    Confirming a ``.txt`` entry displays that book and resets the path.
    Confirming a folder appends it to ``st.session_state.pathway`` and
    renders the picker for the next level.

    Args:
        sub_list: entries available at this level (as returned by
            ``fetch_from_df``); nothing is rendered when it is None or empty.
    """
    if not sub_list:
        # Unknown path or deepest level reached: nothing to choose from.
        return None
    dropdown = col1.selectbox("【选书】", options=sub_list)
    if not col1.button(f"【确定{len(st.session_state.pathway) + 1}】"):
        return None
    st.session_state.pathway.append(dropdown)
    if dropdown.endswith(".txt"):
        _show_book("/".join(st.session_state.pathway))
        return None
    path_text.text("/".join(st.session_state.pathway))
    display_tree(fetch_from_df(st.session_state.pathway))


display_tree(fetch_from_df(st.session_state.pathway))

cc = st.text_area("【输入文本】", height=150)

if st.button("【翻译】"):
    if not cc:
        st.write("请输入文本")
    elif len(cc) > MAX_INPUT_LENGTH:
        st.write(f"句子太长,最多{MAX_INPUT_LENGTH}个字符")
    else:
        st.markdown(f"""```{inference(cc)}```""")