import streamlit as st


# Cache the loaded model/tokenizer so they are created only once per session.
# (On Streamlit >= 1.18, @st.cache_resource is the modern replacement.)
@st.cache(allow_output_mutation=True)
def get_pipe():
    from transformers import AutoTokenizer, AutoModelForCausalLM

    model_name = "heegyu/koalpaca-355m"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.truncation_side = "right"
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer


def get_response(tokenizer, model, context):
    context = f"{context}\n"
    inputs = tokenizer(
        context,
        truncation=True,
        max_length=512,
        return_tensors="pt")

    generation_args = dict(
        max_length=256,
        min_length=64,
        eos_token_id=2,  # corresponds to the </s> token
        do_sample=True,
        top_p=1.0,
        early_stopping=True
    )

    outputs = model.generate(**inputs, **generation_args)
    response = tokenizer.decode(outputs[0])
    print(context)
    print(response)
    # Drop the echoed prompt (by character length) and strip the
    # end-of-sequence marker from the decoded text.
    response = response[len(context):].replace("</s>", "")
    return response


st.title("KoAlpaca-355M")
with st.spinner("loading model..."):
    model, tokenizer = get_pipe()

# UI strings are Korean: "Ask a question" / default question:
# "What is the cause of the conflict between the US and China?"
input_ = st.text_area("질문해보세요", value="미국과 중국의 갈등의 원인이 뭐야?")
ok = st.button("물어보기")  # "Ask"

if input_ is not None and ok and len(input_) > 0:
    with st.spinner("잠시만요"):  # "One moment..."
        response = get_response(tokenizer, model, input_)
    st.text("대답")  # "Answer"
    st.success(response)
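
# --- Usage note (a sketch; assumes this script is saved as app.py,
# a filename not given in the original) ---
# Install the dependencies and launch the demo with Streamlit's CLI:
#   pip install streamlit transformers torch
#   streamlit run app.py
# On first run, from_pretrained() downloads the heegyu/koalpaca-355m
# weights from the Hugging Face Hub, so the initial load takes a while;
# subsequent runs reuse the local cache.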