"""
baseline_interactive.py
"""
import gradio as gr
from transformers import MBartForConditionalGeneration, MBartTokenizer
from transformers import pipeline

model_name = "momo/rsp-sum"
model = MBartForConditionalGeneration.from_pretrained(model_name)
tokenizer = MBartTokenizer.from_pretrained(model_name, src_lang="ko_KR", tgt_lang="ko_KR")

# prefix = "translate English to German: "

def summarization(News, Summary=None):
    """Generate a summary for a Korean news article.

    Parameters
    ----------
    News : str
        The article text to summarize.
    Summary : str, optional
        Unused; kept for backward compatibility with callers that pass a
        second positional argument. The Gradio interface supplies only
        ``News``, so this must have a default (the original signature
        required two arguments and would raise TypeError from the UI).

    Returns
    -------
    str
        The generated summary text.
    """
    summarizer = pipeline("summarization", model=model, tokenizer=tokenizer)
    # Run the pipeline ONCE with the intended length bounds. The original
    # code discarded this result and re-ran the pipeline with default
    # lengths, then returned the loop variable (a raw result dict), which
    # was also unbound if the pipeline yielded nothing.
    results = summarizer(News, min_length=50, max_length=150)
    for result in results:
        print(result)
    # The summarization pipeline returns a list of dicts shaped like
    # [{"summary_text": "..."}]; return the text of the first candidate.
    return results[0]["summary_text"]

if __name__ == '__main__':

    # Build a Gradio app: one text box in, one text box out, wired to
    # summarization(). NOTE(review): gr.inputs / gr.outputs is the legacy
    # pre-3.x Gradio namespace — this pins the script to an old Gradio
    # release; modern versions use gr.Textbox directly. Confirm the
    # installed version before upgrading.
    app = gr.Interface(
        fn=summarization,
        inputs=gr.inputs.Textbox(lines=10, label="News"),
        outputs=gr.outputs.Textbox(lines=10, label="Summary"), 
        title="한국어 뉴스 요약 생성기",
        description="Korean News Summary Generator"
        )
    # Starts a local web server (blocking call).
    app.launch()

# with torch.no_grad():
#     while True:
#         t = input("\nDocument: ")
#         tokens = tokenizer(
#             t,
#             return_tensors="pt",
#             truncation=True,
#             padding=True,
#             max_length=600
#         )

#         input_ids = tokens.input_ids.cuda()
#         attention_mask = tokens.attention_mask.cuda()

#         sample_output = model.generate(
#             input_ids, 
#             max_length=150, 
#             num_beams=5, 
#             early_stopping=True,
#             no_repeat_ngram_size=8,
#     )
#         # print("token:" + str(input_ids.detach().cpu()))
#         # print("token:" + tokenizer.convert_ids_to_tokens(str(input_ids.detach().cpu())))
#         print("Summary: " + tokenizer.decode(sample_output[0], skip_special_tokens=True))