---
license: mit
---
[Korean BART](https://huggingface.co/hyunwoongko/kobart) model fine-tuned for Korean paraphrasing.

The dataset used for fine-tuning is available on the *Files and versions* tab as `dataset.csv`.

The example below loads the fine-tuned model and paraphrases a sample sentence:
```python
import torch
from transformers import BartForConditionalGeneration, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BartForConditionalGeneration.from_pretrained('guialfaro/korean-paraphrasing').to(device)
tokenizer = AutoTokenizer.from_pretrained('guialfaro/korean-paraphrasing')

# "Signing the visitor log is required to visit the 7th floor."
sentence = "7층 방문을 위해 방문록 작성이 필요합니다."
text = f"paraphrase: {sentence} "

# Tokenize with fixed-length padding. Calling the tokenizer directly replaces
# the deprecated batch_encode_plus / pad_to_max_length combination.
encoding = tokenizer(
    [text],
    max_length=256,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)

source_ids = encoding["input_ids"].to(device, dtype=torch.long)
source_mask = encoding["attention_mask"].to(device, dtype=torch.long)

# Beam search with a repetition penalty to discourage verbatim copying.
generated_ids = model.generate(
    input_ids=source_ids,
    attention_mask=source_mask,
    max_length=150,
    num_beams=2,
    repetition_penalty=2.5,
    length_penalty=1.0,
    early_stopping=True,
)

preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]

print(f"Original Sentence :: {sentence}")
print(f"Paraphrased Sentence :: {preds[0]}")
```
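
For completeness, here is a minimal fine-tuning sketch against `dataset.csv`. It is an illustration only: the column names (`text`, `paraphrase`), the base checkpoint choice, and all hyperparameters are assumptions, not the recipe actually used to produce this model.

```python
# Minimal fine-tuning sketch. ASSUMPTIONS: dataset.csv has "text" and
# "paraphrase" columns; hyperparameters are placeholders, not the original recipe.
import pandas as pd
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer,
    BartForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("hyunwoongko/kobart")
model = BartForConditionalGeneration.from_pretrained("hyunwoongko/kobart")

class ParaphraseDataset(Dataset):
    """Pairs each source sentence with its paraphrase as encoder/decoder text."""

    def __init__(self, path, max_length=256):
        self.df = pd.read_csv(path)
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Same "paraphrase: " prefix that the inference snippet above uses.
        inputs = tokenizer(
            f"paraphrase: {row['text']} ",
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        labels = tokenizer(
            row["paraphrase"],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        label_ids = labels["input_ids"].squeeze(0)
        # Replace padding with -100 so it is ignored by the cross-entropy loss.
        label_ids[label_ids == tokenizer.pad_token_id] = -100
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "labels": label_ids,
        }

trainer = Seq2SeqTrainer(
    model=model,
    args=Seq2SeqTrainingArguments(
        output_dir="korean-paraphrasing",
        per_device_train_batch_size=8,
        num_train_epochs=3,
        learning_rate=5e-5,
        save_strategy="epoch",
    ),
    train_dataset=ParaphraseDataset("dataset.csv"),
)
trainer.train()
```

Masking padded label positions with `-100` keeps the padding tokens out of the cross-entropy loss; BART builds its decoder inputs from `labels` automatically, so no manual shifting is needed.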