Dusan Svilarkovic commited on
Commit
375e3cd
1 Parent(s): ab00779

Adding running script

Browse files
Files changed (1) hide show
  1. script.py +115 -0
script.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from predict_clickbait import generate_clickbait
3
+ from datasets import load_dataset,DatasetDict,Dataset
4
+ # from datasets import
5
+ from transformers import AutoTokenizer,AutoModelForSeq2SeqLM
6
+ import numpy as np
7
+ from sklearn.model_selection import train_test_split
8
+ import pandas as pd
9
+ from sklearn.utils.class_weight import compute_class_weight
10
+ import torch
11
+ import pandas as pd
12
+ from model import Model
13
+ import imp
14
+ import os
15
+ import random
16
+ import time
17
+ import pickle
18
+ import math
19
+ from argparse import ArgumentParser
20
+ from collections import namedtuple
21
+ import mock
22
+
23
+ from tqdm import tqdm
24
+ import numpy as np
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ from data import Dataset
28
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
29
+ from constants import *
30
+ from predict_clickbait import generate_clickbait, tokenizer, classifier_tokenizer
31
+
32
+
33
+ # imp.reload(model)
34
+ pretrained_model = "./checkpoint-150/"
35
+ generation_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model, return_dict=True).to(device)
36
+
37
+ device = 'cuda'
38
+ pad_id = 0
39
+
40
+ generation_model.eval()
41
+
42
+ model_args = mock.Mock()
43
+ model_args.task = 'clickbait'
44
+ model_args.device = device
45
+ model_args.checkpoint = './checkpoint-1464/'
46
+
47
+ conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
48
+ conditioning_model = conditioning_model.to(device)
49
+ conditioning_model.eval()
50
+
51
+ condition_lambda = 5.0
52
+ length_cutoff = 50
53
+ precondition_topk = 200
54
+
55
+
56
+ conditioning_model.classifier
57
+
58
+ model_args.checkpoint
59
+
60
+ classifier_tokenizer = AutoTokenizer.from_pretrained(model_args.checkpoint, load_best_model_at_end=True)
61
+
62
+
63
+ def rate_title(input_text, model, tokenizer, device='cuda'):
64
+ # input_text = {
65
+ # "postText": input_text['postText'],
66
+ # "truthClass" : input_text['truthClass']
67
+ # }
68
+ tokenized_input = preprocess_function_title_only_classification(input_text,tokenizer=tokenizer)
69
+ # print(tokenized_input.items())
70
+ dict_tokenized_input = {k : torch.tensor([v]).to(device) for k,v in tokenized_input.items() if k != 'labels'}
71
+ predicted_class = float(model(**dict_tokenized_input).logits)
72
+ actual_class = input_text['truthClass']
73
+
74
+ # print(predicted_class, actual_class)
75
+ return {'predicted_class' : predicted_class}
76
+
77
+ def preprocess_function_title_only_classification(examples,tokenizer=None):
78
+ model_inputs = tokenizer(examples['postText'], padding="longest", truncation=True, max_length=25)
79
+
80
+ model_inputs['labels'] = examples['truthClass']
81
+
82
+ return model_inputs
83
+
84
+
85
+
86
+ def clickbait_generator(article_content, condition_lambda=5.0):
87
+ # result = "Hi {}! 😎. The Mulitple of {} is {}".format(name, number, round(number**2, 2))
88
+ results = generate_clickbait(model=generation_model,
89
+ tokenizer=tokenizer,
90
+ conditioning_model=conditioning_model,
91
+ input_text=[None],
92
+ dataset_info=dataset_info,
93
+ precondition_topk=precondition_topk,
94
+ length_cutoff=length_cutoff,
95
+ condition_lambda=condition_lambda,
96
+ article_content=article_content,
97
+ device=device)
98
+
99
+ return results[0].replace('</s>', '').replace('<pad>', '')
100
+
101
+ title = "Clickbait generator"
102
+ description = """
103
+ "Use the [Fudge](https://github.com/yangkevin2/naacl-2021-fudge-controlled-generation) implementation fine tuned for our purposes to try and create news headline you are looking for!"
104
+ """
105
+
106
+ article = "Check out [the codebase for our model](https://github.com/dsvilarkovic/naacl-2021-fudge-controlled-generation) that this demo is based off of."
107
+
108
+
109
+ app = gr.Interface(
110
+ title = title,
111
+ description = description,
112
+ label = 'Article content or paragraph',
113
+ fn = clickbait_generator,
114
+ inputs=["text", gr.Slider(0, 100, step=0.1, value=5.0)], outputs="text")
115
+ app.launch()