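# Streamlit demo for heegyu/koalpaca-355m, a 355M-parameter Korean
# instruction-following model: the user types a question, the model answers.
# Launch with (filename is an assumption):
#   streamlit run app.py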
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM

# st.cache_resource replaces the deprecated st.cache(allow_output_mutation=True)
# for caching unserializable objects such as models across Streamlit reruns.
@st.cache_resource
def get_pipe():
    model_name = "heegyu/koalpaca-355m"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.truncation_side = "right"  # drop tokens from the end if the prompt is too long
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer

def get_response(tokenizer, model, context):
    # KoAlpaca chat format: the user turn is wrapped in <usr>, the answer follows <sys>.
    context = f"<usr>{context}\n<sys>"
    inputs = tokenizer(
        context,
        truncation=True,
        max_length=512,
        return_tensors="pt")

    # max_new_tokens/min_new_tokens bound only the answer; the original
    # max_length=256 counted the prompt too and could leave no room to generate.
    generation_args = dict(
        max_new_tokens=256,
        min_new_tokens=64,
        eos_token_id=2,  # </s> for this model
        do_sample=True,
        top_p=1.0,  # no nucleus filtering; plain sampling
    )  # early_stopping is dropped: it only affects beam search

    outputs = model.generate(**inputs, **generation_args)
    # Decode only the newly generated tokens; slicing the decoded string by
    # len(context) breaks when detokenization does not round-trip exactly,
    # and skip_special_tokens removes </s> without a manual replace.
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    print(context)   # server-side debug log of the prompt
    print(response)  # and of the generated answer

    return response

st.title("KoAlpaca-355M")

with st.spinner("Loading model..."):
    model, tokenizer = get_pipe()

# Default prompt kept in Korean ("What is the cause of the US-China conflict?")
# since the model is trained on Korean instructions.
input_ = st.text_area("Ask a question", value="미국과 중국의 갈등의 원인이 뭐야?")
ok = st.button("Ask")
if ok and input_:  # st.text_area returns a string, so a truthiness check suffices
    with st.spinner("One moment..."):
        response = get_response(tokenizer, model, input_)
        st.text("Answer")
        st.success(response)