import gradio as gr
import mock
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from fudge.model import Model
from fudge.predict_clickbait import generate_clickbait, tokenizer, classifier_tokenizer
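# Overview: this Gradio demo wires a fine-tuned seq2seq generation model together with a
# separately trained clickbait classifier (the FUDGE "conditioning model") that steers
# decoding toward clickbait-style headlines. The checkpoint directories referenced below
# are relative paths and are expected to be available in the Space's working directory.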
device = 'cpu'

# Seq2seq generation model fine-tuned for headline generation.
pretrained_model = "checkpoint-150/"
generation_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model, return_dict=True).to(device)
pad_id = 0
generation_model.eval()

# FUDGE conditioning model. No vocabulary needs to be passed when reloading:
# the GloVe embeddings are already stored in the model checkpoint.
model_args = mock.Mock()
model_args.task = 'clickbait'
model_args.device = device
model_args.checkpoint = 'checkpoint-1464/'
conditioning_model = Model(model_args, pad_id, vocab_size=None)
conditioning_model = conditioning_model.to(device)
conditioning_model.eval()

# FUDGE decoding hyperparameters.
condition_lambda = 5.0
length_cutoff = 50
precondition_topk = 200

# Tokenizer matching the conditioning classifier checkpoint.
classifier_tokenizer = AutoTokenizer.from_pretrained(model_args.checkpoint)
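# Rough sketch of how the decoding knobs above interact (based on the upstream FUDGE
# recipe; the exact behaviour lives in fudge.predict_clickbait.generate_clickbait):
# at each step the generator's `precondition_topk` most likely next tokens are rescored
# by the clickbait classifier, `condition_lambda` scales how strongly the classifier's
# log-probabilities shift the generator's logits, and `length_cutoff` caps the number
# of generated tokens.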
def rate_title(input_text, model, tokenizer, device='cuda'):
    """Score a single title with the clickbait classifier and return its raw logit."""
    tokenized_input = preprocess_function_title_only_classification(input_text, tokenizer=tokenizer)
    dict_tokenized_input = {k: torch.tensor([v]).to(device) for k, v in tokenized_input.items() if k != 'labels'}
    predicted_class = float(model(**dict_tokenized_input).logits)
    return {'predicted_class': predicted_class}


def preprocess_function_title_only_classification(examples, tokenizer=None):
    """Tokenize the post title and attach its label for classification."""
    model_inputs = tokenizer(examples['postText'], padding="longest", truncation=True, max_length=25)
    model_inputs['labels'] = examples['truthClass']
    return model_inputs
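# Hypothetical usage sketch (comments only, not executed; the classifier object and the
# field names are assumptions inferred from how rate_title is written above):
#   example = {'postText': "You won't believe what happened next", 'truthClass': 1}
#   score = rate_title(example, model=<a sequence-classification model>,
#                      tokenizer=classifier_tokenizer, device=device)
#   # score == {'predicted_class': <raw logit>}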
input_example = "On Friday, a clip of Los Angeles Lakers star LeBron James from the latest episode of \"The Shop: Uninterrupted\" is going viral on Twitter. \"Cause they racist as f--k,\" James said when asked why he hates Boston. James has had many battles with the Boston Celtics in the NBA Playoffs. According to StatMuse, he has played the Celtics 41 times in the NBA Playoffs. He's played them in the playoffs when he was on the Cleveland Cavaliers (the first time), the Miami Heat and the Cavs (the second time). Therefore, he has had quite the experience facing off with them in hostile environments. He is 25-16 against them in the 41 playoff games and averaged 29.6 points per game. (also according to StatMuse). James is currently on the Los Angeles Lakers, and the team missed the postseason this past year. They were the 11th seed in the Western Conference, so they also missed the play-in tournament which was a big surprise. His first year in Los Angeles, they also missed the playoffs, but the following season he led them to their first NBA Championship since the 2010 season. In 2021, they lost in the first-round, so they have been on a downward trajectory since winning the title. Next season will be his 20th season in the NBA, and he is widely regarded as one of the top-five (and to some the greatest) player ever to play the game of basketball. He is 37-years-old, and was the first overall pick out of high school in the 2003 NBA Draft. "
output_example = "Here's why Lebron James hates the Celtics"
textbox_input = gr.Textbox(label="Article content", value=input_example)
textbox_output = gr.Textbox(label="Output clickbait title", value=output_example)
def clickbait_generator(article_content, condition_lambda=5.0):
    """Generate a clickbait headline for the given article text via FUDGE-guided decoding."""
    results = generate_clickbait(model=generation_model,
                                 tokenizer=tokenizer,
                                 conditioning_model=conditioning_model,
                                 input_text=[None],
                                 dataset_info=None,
                                 precondition_topk=precondition_topk,
                                 length_cutoff=length_cutoff,
                                 condition_lambda=condition_lambda,
                                 article_content=article_content,
                                 device=device)
    return results[0].replace('</s>', '').replace('<pad>', '')
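# Example (not executed here): the slider value is passed straight through as condition_lambda,
# so clickbait_generator(input_example, condition_lambda=10.0) should yield a more aggressively
# clickbaity headline than the default of 5.0.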
title = "Clickbaitinator - Controllable Clickbait generator"
description = """
Use the [Fudge](https://github.com/yangkevin2/naacl-2021-fudge-controlled-generation) implementation fine-tuned for our purposes to try and create news headline you are looking for! Use condition_lambda to steer your clickbaitiness higher (by increasing the slider value) or lower (by decreasing the slider value). <br/>
Note that this is using two Transformers and is executed with CPU-only, so it will take a minute or two to finish generating a title.
"""
article = "Check out [the codebase for our model](https://github.com/dsvilarkovic/clickbaitonator) that this demo is based of. You need collaborator access, which you have been probably invited for."
app = gr.Interface(
    fn=clickbait_generator,
    inputs=[textbox_input, gr.Slider(0, 15, step=0.1, value=5.0, label="condition_lambda")],
    outputs=textbox_output,
    title=title,
    description=description,
    article=article,
)
app.launch()
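# When running locally, app.launch(share=True) would additionally expose a temporary public URL.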