korean-paraphrasing / README.md
guialfaro's picture
Update README.md
4ce55a4
|
raw
history blame
1.5 kB
metadata
license: mit

Korean BART model fine-tuned for a paraphrasing task. The dataset used for fine-tuning can be found in the "Files and versions" tab under the name dataset.csv.

import torch
from transformers import BartForConditionalGeneration, AutoTokenizer

# Usage example: paraphrase one Korean sentence with the fine-tuned BART
# checkpoint and print the original next to the generated paraphrase.

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BartForConditionalGeneration.from_pretrained('guialfaro/korean-paraphrasing').to(device)
tokenizer = AutoTokenizer.from_pretrained('guialfaro/korean-paraphrasing')

sentence = "7층 방문을 위해 방문록 작성이 필요합니다."
# The prompt (prefix and trailing space included) is kept exactly as in the
# original example.
text = f"paraphrase: {sentence} "

# Call the tokenizer directly instead of the legacy `batch_encode_plus`.
# `padding="max_length"` alone is sufficient: the old `pad_to_max_length=True`
# flag was a deprecated duplicate of it and only produced warnings.
encoding = tokenizer(
    [text],
    max_length=256,
    truncation=True,
    padding="max_length",
    return_tensors="pt",
)

source_ids = encoding["input_ids"].to(device, dtype=torch.long)
source_mask = encoding["attention_mask"].to(device, dtype=torch.long)

# Beam search with a repetition penalty; the attention mask is passed so the
# padding positions added above are ignored by the encoder.
generated_ids = model.generate(
    input_ids=source_ids,
    attention_mask=source_mask,
    max_length=150,
    num_beams=2,
    repetition_penalty=2.5,
    length_penalty=1.0,
    early_stopping=True,
)

# Decode every generated sequence in one call, dropping special tokens.
preds = tokenizer.batch_decode(
    generated_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)

print(f"Original Sentence :: {sentence}")
print(f"Paraphrased Sentence :: {preds[0]}")