|
import pandas as pd |
|
import numpy as np |
|
|
|
from string import whitespace, punctuation |
|
import re |
|
import unicodedata |
|
from sentence_transformers import SentenceTransformer, util |
|
import gradio as gr |
|
|
|
import pytorch_lightning as pl |
|
import torch |
|
|
|
from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration |
|
from transformers import BartForConditionalGeneration, PreTrainedTokenizerFast |
|
from transformers.optimization import get_cosine_schedule_with_warmup |
|
from torch.utils.data import DataLoader, Dataset |
|
|
|
|
|
|
|
|
|
# Patterns used by CleanEnd, compiled once at import time.
_CLEAN_END_EMAIL = re.compile(
    r'[-_0-9a-z]+@[-_0-9a-z]+(?:\.[0-9a-z]+)+', flags=re.IGNORECASE)
_CLEAN_END_URL = re.compile(
    r'(?:https?:\/\/)?[-_0-9a-z]+(?:\.[-_0-9a-z]+)+', flags=re.IGNORECASE)
# The text after the LAST '.' when it contains a byline / copyright marker
# (reporter titles, "무단 전재 및 재배포 금지", ©, ⓒ, ▶, '@', '/', '=', brackets).
_CLEAN_END_TRAILER = re.compile(
    r'\.([^\.]*(?:기자|특파원|교수|작가|대표|논설|고문|주필|부문장|팀장|장관|원장'
    r'|연구원|이사장|위원|실장|차장|부장|에세이|화백|사설|소장|단장|과장|기획자'
    r'|큐레이터|저작권|평론가|©|ⓒ|\@|\/|=|▶|무단|전재|재배포|금지|\[|\]|\(\))'
    r'[^\.]*)$')
# A leading tag such as "[서울=뉴시스]", "【...】", "<...>" or "◆...◆".
# NOTE(review): the trailing \s only binds to the ◆ alternative and ".+" is
# greedy ("[a] text [b]" is removed up to the last "]"); kept as-is to
# preserve the existing cleaning behaviour.
_CLEAN_END_BRACKET = re.compile(r'^((?:\[.+\])|(?:【.+】)|(?:<.+>)|(?:◆.+◆)\s)')


def CleanEnd(text):
    """Remove contact details, a trailing byline and a leading tag from *text*.

    Steps, in order:
      1. delete e-mail addresses;
      2. delete URLs / dotted host names;
      3. replace the final sentence fragment (after the last '.') with '.'
         when it contains a byline or copyright marker;
      4. delete a leading bracketed tag, then strip surrounding whitespace.

    Args:
        text: article title or body, ideally already whitespace-normalized.

    Returns:
        The cleaned string.
    """
    result = _CLEAN_END_EMAIL.sub('', text)
    result = _CLEAN_END_URL.sub('', result)
    result = _CLEAN_END_TRAILER.sub('.', result)
    return _CLEAN_END_BRACKET.sub('', result).strip()
|
|
|
|
|
# Runs of whitespace or punctuation (every ASCII punctuation char except '%').
# re.escape keeps ']', '\\', '^' and '-' literal inside the character class.
_TEXT_FILTER_SEPARATORS = re.compile(
    '[' + re.escape(whitespace + ''.join(c for c in punctuation if c != '%'))
    + ']+')
# Anything that is not '%', a space, a Hangul jamo (ㄱ-ㅣ) or a Hangul
# syllable (가-힣).
_TEXT_FILTER_NON_TEXT = re.compile(r'[^\% ㄱ-ㅣ가-힣]+')


def TextFilter(text):
    """Keep only Hangul, '%' and single spaces.

    Punctuation/whitespace runs become one space, every other character run
    (digits, Latin letters, symbols) becomes a space, the result is stripped,
    and the double spaces left behind are collapsed again.

    Args:
        text: input string.

    Returns:
        The filtered string.
    """
    result = _TEXT_FILTER_SEPARATORS.sub(' ', text)
    result = _TEXT_FILTER_NON_TEXT.sub(' ', result).strip()
    # Collapse the runs of spaces produced by the previous substitution.
    return _TEXT_FILTER_SEPARATORS.sub(' ', result)
|
|
|
|
|
# Sentence-embedding model shared by all is_clickbait calls; loaded lazily
# on first use instead of on every request.
_clickbait_model = None
_CLICKBAIT_WHITESPACE = re.compile(f'[{whitespace}]+')


def _get_clickbait_model():
    """Return the SentenceTransformer model, loading it only once."""
    global _clickbait_model
    if _clickbait_model is None:
        _clickbait_model = SentenceTransformer('./model/onlineContrastive')
    return _clickbait_model


def _normalize_for_clickbait(text):
    """Collapse whitespace, NFC-normalize and strip, then CleanEnd + TextFilter."""
    text = unicodedata.normalize(
        'NFC', _CLICKBAIT_WHITESPACE.sub(' ', text)).strip()
    return TextFilter(CleanEnd(text))


def is_clickbait(title, content, threshold=0.815):
    """Compare a headline with its article body by embedding cosine similarity.

    Args:
        title: raw headline text.
        content: raw article body text.
        threshold: similarity below which the pair gets label 0.

    Returns:
        ``(label, similarity)`` where ``label`` is 0 when the cleaned title and
        body are less similar than ``threshold`` (new_headline treats this as
        clickbait) and 1 otherwise; ``similarity`` is the cosine score as a
        NumPy scalar.
    """
    model = _get_clickbait_model()

    title_embedding = model.encode(
        _normalize_for_clickbait(title), convert_to_tensor=True)
    content_embedding = model.encode(
        _normalize_for_clickbait(content), convert_to_tensor=True)

    # .cpu() so this also works when the model encodes on a GPU.
    similarity = util.cos_sim(
        title_embedding, content_embedding).cpu().numpy()[0][0]

    label = 0 if similarity < threshold else 1
    return label, similarity
|
|
|
|
|
|
|
|
|
# Two-row placeholder frame ('1', '2') handed to OneSourceDataModule as its
# training data.
df_train = pd.DataFrame({
    'input_text': ['1', '2'],
    'target_text': ['1', '2'],
})
|
|
|
|
|
def CleanEnd_g(text):
    """Strip e-mail addresses from *text*.

    A lighter variant of CleanEnd, applied to generated headlines.

    Args:
        text: input string.

    Returns:
        *text* with every case-insensitive e-mail-like token removed.
    """
    return re.sub(r'[-_0-9a-z]+@[-_0-9a-z]+(?:\.[0-9a-z]+)+', '', text,
                  flags=re.IGNORECASE)
|
|
|
|
|
class DatasetFromDataframe(Dataset):
    """Map-style dataset producing fixed-length BART encoder/decoder tensors.

    Args:
        df: DataFrame with 'input_text' and 'target_text' columns.
        dataset_args: dict with 'max_length' (int) and 'tokenizer' (an object
            exposing ``encode(str) -> list[int]`` and ``pad_token_id``).
    """

    def __init__(self, df, dataset_args):
        self.data = df
        self.max_length = dataset_args['max_length']
        self.tokenizer = dataset_args['tokenizer']
        self.start_token = '<s>'
        self.end_token = '</s>'

    def __len__(self):
        return len(self.data)

    def create_tokens(self, text):
        """Encode '<s>' + text + '</s>' to exactly ``max_length`` ids.

        Short sequences are right-padded with ``pad_token_id`` (mask 0); long
        ones are cut to ``max_length - 1`` ids and the encoding of the end
        token is appended.

        Returns:
            ``(token_ids, attention_mask)`` as lists of ints.
        """
        tokens = self.tokenizer.encode(
            self.start_token + text + self.end_token)
        token_count = len(tokens)
        remain = self.max_length - token_count

        if remain >= 0:
            tokens = tokens + [self.tokenizer.pad_token_id] * remain
            attention_mask = [1] * token_count + [0] * remain
        else:
            tokens = (tokens[:self.max_length - 1]
                      + self.tokenizer.encode(self.end_token))
            attention_mask = [1] * self.max_length

        return tokens, attention_mask

    def __getitem__(self, index):
        """Return the model inputs for row *index* as LongTensors."""
        record = self.data.iloc[index]
        question, answer = record['input_text'], record['target_text']

        input_id, input_mask = self.create_tokens(question)
        output_id, output_mask = self.create_tokens(answer)

        # Labels are the decoder inputs shifted left by one. Targets that are
        # padding get -100 so the loss ignores them (previously the pad ids
        # were kept, training the model to predict padding).
        end = self.max_length + 1
        label = [
            token if keep else -100
            for token, keep in zip(output_id[1:end], output_mask[1:end])
        ]
        label = label + (self.max_length - len(label)) * [-100]

        return {
            'input_ids': torch.LongTensor(input_id),
            'attention_mask': torch.LongTensor(input_mask),
            'decoder_input_ids': torch.LongTensor(output_id),
            'decoder_attention_mask': torch.LongTensor(output_mask),
            "labels": torch.LongTensor(label)
        }
|
|
|
|
|
class OneSourceDataModule(pl.LightningDataModule):
    """LightningDataModule serving one DataFrame for train, val and test.

    Keyword Args:
        data: DataFrame with 'input_text' / 'target_text' columns. When
            omitted, the module-level ``df_train`` is used.
        dataset_args: dict forwarded to DatasetFromDataframe
            ('tokenizer', 'max_length').
        batch_size: loader batch size (falsy values fall back to 32).
        train_size: stored (default 0.9) but currently unused.
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()

        self.data = kwargs.get('data')
        self.dataset_args = kwargs.get("dataset_args")
        self.batch_size = kwargs.get("batch_size") or 32
        # NOTE(review): no train/validation split is performed; both datasets
        # are built from the same rows.
        self.train_size = kwargs.get("train_size") or 0.9

    def setup(self, stage=""):
        # Previously the `data` argument was ignored and the global df_train
        # was always used; keep that only as the fallback.
        data = self.data if self.data is not None else df_train
        self.trainset = DatasetFromDataframe(data, self.dataset_args)
        self.testset = DatasetFromDataframe(data, self.dataset_args)

    def _loader(self, dataset):
        """Build an unshuffled DataLoader with the configured batch size."""
        return DataLoader(dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        return self._loader(self.trainset)

    def val_dataloader(self):
        return self._loader(self.testset)

    def test_dataloader(self):
        return self._loader(self.testset)
|
|
|
|
|
class KoBARTConditionalGeneration(pl.LightningModule):
    """LightningModule wrapping a Hugging Face BART model for fine-tuning and
    beam-search generation.

    Args:
        hparams: dict merged into ``self.hparams``. Keys read here: 'lr',
            'warmup_ratio', 'max_epochs' (optimizer set-up) and 'max_length'
            (generation in ``test``).
        **kwargs: must provide 'model' (the BART model) and 'tokenizer'.
    """

    def __init__(self, hparams, **kwargs):
        super().__init__()
        self.hparams.update(hparams)

        self.model = kwargs['model']
        self.tokenizer = kwargs['tokenizer']

        self.model.train()

    def configure_optimizers(self):
        """Build AdamW plus a per-step cosine schedule with linear warm-up.

        Biases and LayerNorm parameters get no weight decay; all other
        parameters use 0.01. The schedule length is derived from the
        training loader.
        """
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        named_params = list(self.model.named_parameters())
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in named_params
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': 0.01,
            },
            {
                'params': [p for n, p in named_params
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
            },
        ]
        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.hparams.lr
        )

        # When fitting with a LightningDataModule the loaders are defined on
        # the datamodule rather than on this module.
        datamodule = getattr(self.trainer, 'datamodule', None)
        if datamodule is not None:
            train_loader = datamodule.train_dataloader()
        else:
            train_loader = self.train_dataloader()

        data_len = len(train_loader.dataset)
        print(f'학습 데이터 수: {data_len}')

        # len(loader) counts the final partial batch. The former
        # int(data_len / batch_size * epochs) under-counted, so training ran
        # past num_training_steps, where the cosine schedule raises the
        # learning rate again.
        num_train_steps = len(train_loader) * self.hparams.max_epochs
        print(f'Step 수: {num_train_steps}')

        num_warmup_steps = int(num_train_steps * self.hparams.warmup_ratio)
        print(f'Warmup Step 수: {num_warmup_steps}')

        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_train_steps
        )
        lr_scheduler = {
            'scheduler': scheduler,
            'monitor': 'loss',
            'interval': 'step',
            'frequency': 1
        }
        return [optimizer], [lr_scheduler]

    def forward(self, inputs):
        """Run the wrapped model on a batch dict from DatasetFromDataframe."""
        return self.model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            decoder_input_ids=inputs['decoder_input_ids'],
            decoder_attention_mask=inputs['decoder_attention_mask'],
            labels=inputs['labels'],
            return_dict=True
        )

    def training_step(self, batch, batch_idx):
        loss = self(batch).loss
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self(batch).loss
        # The loss used to be computed and discarded; log it so validation
        # has an observable result.
        self.log('val_loss', loss)
        return loss

    def test(self, text):
        """Generate output for *text* with 10-beam search.

        The input is wrapped as '<s>' + text + '</s>' and padded/truncated to
        ``hparams.max_length``. The result is decoded without
        ``skip_special_tokens``, so special tokens remain in the string.
        """
        tokens = self.tokenizer.encode("<s>" + text + "</s>")
        token_count = len(tokens)
        remain = self.hparams.max_length - token_count

        if remain >= 0:
            tokens = tokens + [self.tokenizer.pad_token_id] * remain
            attention_mask = [1] * token_count + [0] * remain
        else:
            tokens = (tokens[: self.hparams.max_length - 1]
                      + self.tokenizer.encode("</s>"))
            attention_mask = [1] * self.hparams.max_length

        device = self.model.device
        input_ids = torch.LongTensor([tokens]).to(device)
        attention_mask = torch.LongTensor([attention_mask]).to(device)

        # __init__ leaves the model in train mode (dropout on); generate in
        # eval mode and restore the previous mode afterwards.
        was_training = self.model.training
        self.model.eval()
        try:
            result = self.model.generate(
                input_ids,
                max_length=self.hparams.max_length,
                attention_mask=attention_mask,
                num_beams=10
            )[0]
        finally:
            self.model.train(was_training)

        return self.tokenizer.decode(result)
|
|
|
|
|
# Per-process cache of the models used by `generation`, filled on first call.
_GENERATION_MODELS = {}


def _load_generation_models():
    """Return (tokenizer, summarizer, title_generator), loading them once."""
    if not _GENERATION_MODELS:
        tokenizer = PreTrainedTokenizerFast.from_pretrained(
            "gogamza/kobart-summarization")
        summarizer = BartForConditionalGeneration.from_pretrained(
            "gogamza/kobart-summarization")
        title_generator = KoBARTConditionalGeneration(
            {
                "lr": 5e-6,
                "warmup_ratio": 0.1,
                "batch_size": 32,
                "max_length": 128,
                "max_epochs": 0,
            },
            tokenizer=tokenizer,
            model=BartForConditionalGeneration.from_pretrained(
                './model/final2.h5'),
        )
        # KoBARTConditionalGeneration.__init__ puts the wrapped model in train
        # mode; generation must run with dropout disabled.
        title_generator.eval()
        # Fill the cache in one step so a failed load leaves it empty.
        _GENERATION_MODELS.update(
            tokenizer=tokenizer,
            summarizer=summarizer,
            title_generator=title_generator,
        )
    return (_GENERATION_MODELS['tokenizer'],
            _GENERATION_MODELS['summarizer'],
            _GENERATION_MODELS['title_generator'])


def generation(szContent):
    """Generate a new headline for an article body.

    The first 500 characters are summarised with
    gogamza/kobart-summarization; the summary is then turned into a headline
    by the fine-tuned model in ./model/final2.h5. The headline is
    whitespace-collapsed, NFC-normalized, stripped, and passed through
    CleanEnd_g and TextFilter (which also removes leftover special tokens).

    Args:
        szContent: article body text.

    Returns:
        The generated headline string.
    """
    tokenizer, summarizer, title_generator = _load_generation_models()

    input_ids = tokenizer.encode(szContent[:500], return_tensors="pt")
    summary = summarizer.generate(
        input_ids=input_ids,
        bos_token_id=summarizer.config.bos_token_id,
        eos_token_id=summarizer.config.eos_token_id,
        length_penalty=.3,
        max_length=35,
        min_length=25,
        num_beams=5)
    szSummary = tokenizer.decode(summary[0], skip_special_tokens=True)
    print(szSummary)

    # Inference only: the title model needs no Trainer (the former
    # zero-epoch trainer.fit() never updated any weights).
    szTitle = title_generator.test(szSummary)

    pattern_whitespace = re.compile(f'[{whitespace}]+')
    title = unicodedata.normalize(
        'NFC', pattern_whitespace.sub(' ', szTitle)).strip()
    title = CleanEnd_g(title)
    return TextFilter(title)
|
|
|
|
|
def new_headline(title, content):
    """Return a generated headline for clickbait articles, else a notice.

    is_clickbait yields label 0 when the title/body similarity falls below its
    threshold; only then is a replacement headline generated from the body.

    Args:
        title: the article's headline.
        content: the article's body text.

    Returns:
        The generated headline, or a message saying the article is not
        clickbait.
    """
    label, _similarity = is_clickbait(title, content)
    if label == 0:
        return generation(content)
    # is_clickbait only returns 0 or 1; always return a string rather than
    # falling through to None.
    return '낚시성 기사가 아닙니다.'
|
|
|
|
|
|
|
# Gradio UI: headline + body inputs, one button, one text output wired to
# new_headline. Korean UI strings restored from mis-encoded (cp874) text.
with gr.Blocks() as demo1:
    # Kept flush-left with no blank lines so the <h1> stays a single HTML
    # block when rendered as Markdown.
    gr.Markdown(
        """
<h1 align="center">
clickbait news classifier and new headline generator
</h1>
""")

    gr.Markdown(
        """
뉴스 기사 제목과 본문을 입력하면 낚시성 기사인지 분류하고,
낚시성 기사이면 새로운 제목을 생성해주는 프로그램입니다.
""")

    with gr.Row():
        with gr.Column():
            inputs = [
                gr.Textbox(placeholder="뉴스기사 제목을 입력해주세요",
                           label='headline'),
                gr.Textbox(lines=10,
                           placeholder="뉴스기사 본문을 입력해주세요",
                           label='content'),
            ]
            with gr.Row():
                btn = gr.Button("결과 출력")
        with gr.Column():
            output = gr.Text(label='Result')

    btn.click(fn=new_headline, inputs=inputs, outputs=output)


if __name__ == "__main__":
    demo1.launch()
|
|