File size: 1,207 Bytes
f4807cc
 
 
27a7c55
f4807cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import gradio as gr

def get_pipe():
    """Load the KoAlpaca-355M checkpoint.

    Returns:
        A ``(model, tokenizer)`` pair — note this is NOT a transformers
        ``pipeline`` object despite the function's name.
    """
    # Imported lazily so the (heavy) transformers import cost is only
    # paid when the model is actually loaded.
    from transformers import AutoTokenizer, AutoModelForCausalLM

    checkpoint = "heegyu/koalpaca-355m"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    # Keep the beginning of over-long prompts; drop tokens from the end.
    tok.truncation_side = "right"
    lm = AutoModelForCausalLM.from_pretrained(checkpoint)
    return lm, tok

def get_response(tokenizer, model, context):
    """Generate one chat reply for a single user turn.

    Args:
        tokenizer: HF-style tokenizer; expected to have
            ``truncation_side="right"`` (set in ``get_pipe``).
        model: causal LM exposing ``generate()``.
        context: the raw user message (Korean text).

    Returns:
        The generated reply with the echoed prompt prefix and the
        trailing ``</s>`` marker stripped.
    """
    prompt = f"<usr>{context}\n<sys>"
    inputs = tokenizer(
        prompt,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )

    generation_args = dict(
        # Bound the NEW tokens, not the total sequence length: the previous
        # max_length=256 counted the prompt too, so any prompt longer than
        # 256 tokens left no room for generation at all (inputs are allowed
        # up to 512 tokens above). min_new_tokens likewise enforces the
        # evident intent — min_length=64 was trivially satisfied by the
        # prompt alone.
        max_new_tokens=256,
        min_new_tokens=64,
        eos_token_id=2,
        do_sample=True,
        top_p=1.0,
        # early_stopping removed: it only applies to beam search and merely
        # triggers a warning under pure sampling.
    )

    outputs = model.generate(**inputs, **generation_args)
    response = tokenizer.decode(outputs[0])
    # Strip the echoed prompt and the EOS marker.
    # NOTE(review): assumes decode() reproduces the prompt byte-for-byte at
    # the start of the output — holds for this tokenizer, but fragile in
    # general; confirm if the tokenizer is ever swapped.
    return response[len(prompt):].replace("</s>", "")

# Load the model once at import time so every request reuses it.
model, tokenizer = get_pipe()

def ask_question(question):
    """Gradio callback: forward the user's question to the model."""
    return get_response(tokenizer, model, question)

demo = gr.Interface(
    fn=ask_question,
    inputs="text",
    outputs="text",
    title="KoAlpaca-355M",
    description="한국어로 질문하세요.",
)
demo.launch()