import os import streamlit as st from yasem import SpladeEmbedder if os.getenv("SPACE_ID"): USE_HF_SPACE = True os.environ["HF_HOME"] = "/data/.huggingface" os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface" else: USE_HF_SPACE = False MODEL_NAME = os.environ.get("MODEL_NAME", "hotchpotch/japanese-splade-base-v1") @st.cache_resource def get_embedder(model_name: str = MODEL_NAME) -> SpladeEmbedder: embedder = SpladeEmbedder( model_name, ) return embedder def get_token_values_sorted(input_text: str) -> list[tuple[float, str]]: embedder = get_embedder() embeddings = embedder.encode([input_text]) token_values = embedder.get_token_values(embeddings[0]) sorted_tokens = sorted(token_values.items(), key=lambda item: item[1], reverse=True) # type: ignore return [(value, key) for key, value in sorted_tokens] def main(): st.set_page_config( page_title="SPLADE 日本語 demo", layout="centered", initial_sidebar_state="auto", ) st.title("SPLADE 日本語 demo") get_embedder() st.markdown(""" [hotchpotch/japanese-splade-base-v1](https://huggingface.co/hotchpotch/japanese-splade-base-v1)を使って、テキストからSPLADEのスパースベクトルに変換するデモです。 """) input_text = st.text_area("テキスト", height=200) if st.button("変換"): if input_text.strip(): with st.spinner("変換中..."): sorted_tokens = get_token_values_sorted(input_text) total_tokens = len(sorted_tokens) st.markdown(f"### 結果 (トークン数: {total_tokens})") if sorted_tokens: formatted_data = [ {"スコア": freq, "単語(vocab)": word} for freq, word in sorted_tokens ] st.table(formatted_data) else: st.warning("入力テキストから有効な単語が見つかりませんでした。") else: st.warning("テキストを入力してください。") if __name__ == "__main__": main()