import gradio as gr

# Install runtime dependencies at startup.
# pip.main() was removed in pip 10+; invoke pip via the current interpreter instead.
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "transformers"])
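
# Note: on Hugging Face Spaces the more idiomatic route is to declare these
# packages in a requirements.txt next to app.py, e.g.:
#
#   gradio
#   torch
#   transformers
#
# in which case the in-script install above can be dropped.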

import torch
import transformers


# Load the fine-tuned KoBART checkpoint (weights + training config) saved during training.
def load_model(model_path):
    saved_data = torch.load(
        model_path,
        map_location="cpu"
    )

    bart_best = saved_data["model"]
    train_config = saved_data["config"]
    tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-base-v1')

    ## Load weights.
    model = transformers.BartForConditionalGeneration.from_pretrained('gogamza/kobart-base-v1')
    model.load_state_dict(bart_best)

    return model, tokenizer
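
# The checkpoint above is assumed to have been produced during training with
# something along these lines (hypothetical sketch, not part of this app):
#
#   torch.save(
#       {"model": model.state_dict(), "config": train_config},
#       "kobart-model-poem.pth",
#   )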


# Load the model once at startup so the checkpoint is not re-read on every request.
model_path = "./kobart-model-poem.pth"
model, tokenizer = load_model(model_path=model_path)


def inference(prompt):
    # Encode the prompt and add a batch dimension: (seq_len,) -> (1, seq_len).
    input_ids = tokenizer.encode(prompt)
    input_ids = torch.tensor(input_ids).unsqueeze(0)

    # Generate with default decoding settings and strip special tokens.
    output = model.generate(input_ids)
    output = tokenizer.decode(output[0], skip_special_tokens=True)

    return output
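
# model.generate() above uses its default decoding settings; for longer or more
# varied poems one might pass explicit generation arguments (illustrative values,
# not tuned for this model):
#
#   output = model.generate(
#       input_ids,
#       max_length=128,
#       num_beams=5,
#       no_repeat_ngram_size=3,
#   )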


demo = gr.Interface(
    fn=inference,
    inputs="text",
    outputs="text",  # the generated poem is returned as plain text
    examples=[
        "μž‘μ€ 씨앗 ν•˜λ‚˜ 길가에 λ–¨μ–΄μ‘Œμ–΄μš”\nμ•ˆλΌμš” μ•ˆλΌμš” λͺ¨λ‘λ“€ λ°ŸμœΌλ‹ˆκΉŒμš”\nμž‘μ€ 씨앗 ν•˜λ‚˜ λŒλ°­μ— λ–¨μ–΄μ‘Œμ–΄μš”\nμ‹«μ–΄μš” μ‹«μ–΄μš” 크게 μžλž„ 수 μ—†μ–΄μš”\nμž‘μ€ 씨앗 ν•˜λ‚˜ κ°€μ‹œλ°­μ— λ–¨μ–΄μ‘Œμ–΄μš”\nμ•„ μ•Όμ•Ό, μ•„νŒŒμš” μˆ¨μ„ 쉴 μˆ˜κ°€ μ—†μ–΄μš”\nμž‘μ€ 씨앗 ν•˜λ‚˜ 쒋은 밭에 λ–¨μ–΄μ‘Œμ–΄μš”\nμ’‹μ•„μš” μ’‹μ•„μš” 잘 μžλΌμ„œ 쒋은 λ‚˜λ¬΄ λ˜κ² μ–΄μš”"
    ],
)

# Setting launch(share=True) generates a publicly accessible link.
demo.launch()
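
# Quick smoke test without the web UI (hypothetical; run with demo.launch()
# commented out and the checkpoint file present):
#
#   print(inference("μž‘μ€ 씨앗 ν•˜λ‚˜ 길가에 λ–¨μ–΄μ‘Œμ–΄μš”"))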