"""Gradio demo: generate clickbait-style headlines with a FUDGE-controlled
seq2seq model (fine-tuned generator + clickbait-classifier conditioning model)."""
# import os
# os.chdir('naacl-2021-fudge-controlled-generation/')
import gradio as gr
from fudge.predict_clickbait import generate_clickbait, tokenizer, classifier_tokenizer
from datasets import load_dataset, DatasetDict, Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn.utils.class_weight import compute_class_weight
import torch
from fudge.model import Model
import os
from argparse import ArgumentParser
from collections import namedtuple
import mock
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from fudge.data import Dataset
from fudge.util import save_checkpoint, ProgressMeter, AverageMeter, num_params
from fudge.constants import *


# BUG FIX: `device` was assigned AFTER `generation_model...to(device)`, which
# raised NameError at import time. It must be defined before any model is
# moved to it.
device = 'cuda'
pad_id = 0

# Fine-tuned seq2seq headline generator, loaded from a local checkpoint.
pretrained_model = "checkpoint-150/"
generation_model = AutoModelForSeq2SeqLM.from_pretrained(
    pretrained_model, return_dict=True).to(device)
generation_model.eval()

# Configuration for the FUDGE conditioning (classifier) model. A Mock stands
# in for the argparse namespace the Model constructor expects.
model_args = mock.Mock()
model_args.task = 'clickbait'
model_args.device = device
model_args.checkpoint = 'checkpoint-1464/'

# conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word))
# No need to get the GloVe embeddings when reloading, since they're saved in
# the model checkpoint anyway.
conditioning_model = Model(model_args, pad_id, vocab_size=None)
conditioning_model = conditioning_model.to(device)
conditioning_model.eval()

# Generation hyperparameters.
condition_lambda = 5.0   # strength of the conditioning model's influence
length_cutoff = 50       # maximum generated headline length
precondition_topk = 200  # number of top-k candidate tokens rescored by FUDGE

# NOTE(review): `load_best_model_at_end` is a transformers Trainer argument,
# not a tokenizer one — AutoTokenizer silently ignores it. Kept as-is.
classifier_tokenizer = AutoTokenizer.from_pretrained(
    model_args.checkpoint, load_best_model_at_end=True)


def rate_title(input_text, model, tokenizer, device='cuda'):
    """Score one title with the clickbait classifier.

    Args:
        input_text: dict with 'postText' (the title) and 'truthClass' (label).
        model: sequence-classification model producing a single logit.
        tokenizer: tokenizer matching `model`.
        device: torch device string the input tensors are moved to.

    Returns:
        dict with key 'predicted_class' holding the model's logit as a float.
    """
    tokenized_input = preprocess_function_title_only_classification(
        input_text, tokenizer=tokenizer)
    # Add a batch dimension to every field; 'labels' is not a model input.
    dict_tokenized_input = {
        k: torch.tensor([v]).to(device)
        for k, v in tokenized_input.items() if k != 'labels'
    }
    predicted_class = float(model(**dict_tokenized_input).logits)
    return {'predicted_class': predicted_class}


def preprocess_function_title_only_classification(examples, tokenizer=None):
    """Tokenize only the title ('postText') and attach its gold label.

    Titles are truncated/padded to at most 25 tokens.
    """
    model_inputs = tokenizer(
        examples['postText'], padding="longest", truncation=True, max_length=25)
    model_inputs['labels'] = examples['truthClass']
    return model_inputs


def clickbait_generator(article_content, condition_lambda=5.0):
    """Generate a clickbait headline for `article_content` via FUDGE decoding.

    NOTE(review): `dataset_info` is referenced below but never defined in this
    module (its assignment is commented out above), so calling this function
    raises NameError. TODO: define dataset_info or drop the argument if
    generate_clickbait no longer needs it.
    """
    results = generate_clickbait(
        model=generation_model,
        tokenizer=tokenizer,
        conditioning_model=conditioning_model,
        input_text=[None],
        dataset_info=dataset_info,
        precondition_topk=precondition_topk,
        length_cutoff=length_cutoff,
        condition_lambda=condition_lambda,
        article_content=article_content,
        device=device,
    )
    # NOTE(review): replacing '' with '' is a no-op — the search strings were
    # presumably special tokens (e.g. '<pad>') lost in a copy/encoding step.
    # Left byte-identical pending confirmation of the intended tokens.
    return results[0].replace('', '').replace('', '')


title = "Clickbait generator"
description = """
"Use the [Fudge](https://github.com/yangkevin2/naacl-2021-fudge-controlled-generation) implementation fine tuned for our purposes to try and create news headline you are looking for!"
"""
article = "Check out [the codebase for our model](https://github.com/dsvilarkovic/naacl-2021-fudge-controlled-generation) that this demo is based off of."

# Web UI: free-text article input plus a slider for the conditioning strength.
app = gr.Interface(
    title=title,
    description=description,
    label='Article content or paragraph',
    fn=clickbait_generator,
    inputs=["text", gr.Slider(0, 100, step=0.1, value=5.0)],
    outputs="text",
)
app.launch()