# clickbaitonator / script.py
# Author: Dusan Svilarkovic
import gradio as gr
from predict_clickbait import generate_clickbait
from datasets import load_dataset,DatasetDict,Dataset
# from datasets import
from transformers import AutoTokenizer,AutoModelForSeq2SeqLM
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn.utils.class_weight import compute_class_weight
import torch
import pandas as pd
from model import Model
import imp
import os
import random
import time
import pickle
import math
from argparse import ArgumentParser
from collections import namedtuple
import mock
from tqdm import tqdm
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from data import Dataset
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
from constants import *
from predict_clickbait import generate_clickbait, tokenizer, classifier_tokenizer
import os
# --- Model and tokenizer setup (runs once at import time) ---
os.chdir('naacl-2021-fudge-controlled-generation/')

# Define the device BEFORE any model is moved to it.  In the original,
# `device` was referenced on the `.to(device)` call one line before it was
# assigned, which raises NameError at import time.
device = 'cuda'
pad_id = 0

# Fine-tuned seq2seq generation model (relative to the chdir above).
pretrained_model = "../checkpoint-150/"
generation_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model, return_dict=True).to(device)
generation_model.eval()

# FUDGE conditioning (clickbait classifier) model.  A Mock stands in for the
# argparse namespace the Model constructor expects.
model_args = mock.Mock()
model_args.task = 'clickbait'
model_args.device = device
model_args.checkpoint = '../checkpoint-1464/'
# vocab_size=None: no need for the glove embeddings when reloading, since
# they're saved in the model checkpoint anyway.
conditioning_model = Model(model_args, pad_id, vocab_size=None)
conditioning_model = conditioning_model.to(device)
conditioning_model.eval()

# Generation hyper-parameters used by the Gradio callback.
condition_lambda = 5.0
length_cutoff = 50
precondition_topk = 200

# NOTE(review): `load_best_model_at_end` is a Trainer argument, not a
# tokenizer option -- presumably it is silently ignored here; verify.
classifier_tokenizer = AutoTokenizer.from_pretrained(model_args.checkpoint, load_best_model_at_end=True)
def rate_title(input_text, model, tokenizer, device='cuda'):
    """Score a headline with the clickbait classifier.

    input_text : dict providing 'postText' (the headline) and 'truthClass'
        (the gold label).
    model : classification model whose output exposes `.logits`.
    tokenizer : tokenizer forwarded to the preprocessing helper.
    device : torch device string the input tensors are moved to.

    Returns {'predicted_class': <float logit>}.
    """
    encoded = preprocess_function_title_only_classification(input_text, tokenizer=tokenizer)
    # Add a batch dimension to each field and move it to the target device;
    # 'labels' is excluded so the model receives only its inputs.
    batch = {
        name: torch.tensor([values]).to(device)
        for name, values in encoded.items()
        if name != 'labels'
    }
    score = float(model(**batch).logits)
    # Read the gold label exactly as the original did (this preserves the
    # KeyError when 'truthClass' is missing), even though it is not returned.
    _ = input_text['truthClass']
    return {'predicted_class': score}
def preprocess_function_title_only_classification(examples, tokenizer=None):
    """Tokenize the headline text and attach the gold label.

    `examples` must provide 'postText' (the headline) and 'truthClass'
    (the label).  Headlines are padded to the longest in the batch and
    truncated to at most 25 tokens.
    """
    encoded = tokenizer(
        examples['postText'],
        padding="longest",
        truncation=True,
        max_length=25,
    )
    encoded['labels'] = examples['truthClass']
    return encoded
def clickbait_generator(article_content, condition_lambda=5.0):
    """Generate a clickbait headline for `article_content` via FUDGE.

    `condition_lambda` controls how strongly the clickbait classifier
    steers generation.  Relies on the module-level globals set up above
    (generation_model, conditioning_model, tokenizer, ...).

    NOTE(review): `dataset_info` is never assigned in this script -- it is
    presumably exported by `from constants import *`; verify.
    """
    outputs = generate_clickbait(
        model=generation_model,
        tokenizer=tokenizer,
        conditioning_model=conditioning_model,
        input_text=[None],
        dataset_info=dataset_info,
        precondition_topk=precondition_topk,
        length_cutoff=length_cutoff,
        condition_lambda=condition_lambda,
        article_content=article_content,
        device=device,
    )
    # Strip sentinel tokens from the first (and only) generated sequence.
    headline = outputs[0]
    return headline.replace('</s>', '').replace('<pad>', '')
# --- Gradio UI ---
title = "Clickbait generator"
description = """
"Use the [Fudge](https://github.com/yangkevin2/naacl-2021-fudge-controlled-generation) implementation fine tuned for our purposes to try and create news headline you are looking for!"
"""
article = "Check out [the codebase for our model](https://github.com/dsvilarkovic/naacl-2021-fudge-controlled-generation) that this demo is based off of."

# Fixes from the original:
#  * `label` is not a gr.Interface parameter (it was silently ignored as an
#    unexpected kwarg); the label belongs on the input Textbox component.
#  * `article` was defined but never passed to the Interface; pass it so the
#    footer text actually renders.
app = gr.Interface(
    title=title,
    description=description,
    article=article,
    fn=clickbait_generator,
    inputs=[
        gr.Textbox(label='Article content or paragraph'),
        gr.Slider(0, 100, step=0.1, value=5.0, label='condition_lambda'),
    ],
    outputs="text",
)
app.launch()