hemhemoh commited on
Commit
cc74e3f
1 Parent(s): 78aeedc

Delete main.py

Browse files
Files changed (1) hide show
  1. main.py +0 -125
main.py DELETED
@@ -1,125 +0,0 @@
1
- from flask import Flask,render_template,url_for,request, session
2
- from werkzeug.utils import secure_filename
3
- import torch
4
- import pytorch_lightning as pl
5
- from transformers import T5ForConditionalGeneration,T5TokenizerFast as T5Tokenizer, AdamW
6
- from werkzeug.datastructures import FileStorage
7
-
8
-
9
- MODEL_NAME ="t5-base"
10
- tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
11
-
12
- class NewsSummaryModel(pl.LightningModule):
13
- def __init__(self):
14
- super().__init__()
15
- self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
16
-
17
- def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
18
-
19
- output = self.model(
20
- input_ids,
21
- attention_mask = attention_mask,
22
- labels = labels,
23
- decoder_attention_mask = decoder_attention_mask
24
- )
25
- return output.loss, output.logits
26
-
27
- def training_step(self, batch, batch_idx):
28
- input_ids = batch["text_input_ids"]
29
- attention_mask = batch["text_attention_mask"]
30
- labels = batch["labels"]
31
- labels_attention_mask = batch["labels_attention_mask"]
32
-
33
- loss, outputs = self(
34
- input_ids = input_ids,
35
- attention_mask = attention_mask,
36
- decoder_attention_mask = labels_attention_mask,
37
- labels = labels
38
- )
39
-
40
- self.log("train_loss", loss, prog_bar =True, logger=True)
41
- return loss
42
-
43
- def validation_step(self, batch, batch_idx):
44
- input_ids = batch["text_input_ids"]
45
- attention_mask = batch["text_attention_mask"]
46
- labels = batch["labels"]
47
- labels_attention_mask = batch["labels_attention_mask"]
48
-
49
- loss, outputs = self(
50
- input_ids = input_ids,
51
- attention_mask = attention_mask,
52
- decoder_attention_mask = labels_attention_mask,
53
- labels = labels
54
- )
55
-
56
- self.log("val_loss", loss, prog_bar =True, logger=True)
57
- return loss
58
-
59
- def test_step(self, batch, batch_idx):
60
- input_ids = batch["text_input_ids"]
61
- attention_mask = batch["text_attention_mask"]
62
- labels = batch["labels"]
63
- labels_attention_mask = batch["labels_attention_mask"]
64
-
65
- loss, outputs = self(
66
- input_ids = input_ids,
67
- attention_mask = attention_mask,
68
- decoder_attention_mask = labels_attention_mask,
69
- labels = labels
70
- )
71
-
72
- self.log("test_loss", loss, prog_bar =True, logger=True)
73
- return loss
74
-
75
- def configure_optimizers(self):
76
- return AdamW(self.parameters(), lr=0.0001)
77
-
78
-
79
- filename = 'model.pth'
80
- model = torch.load(open(filename, 'rb'))
81
-
82
- def summarize(text):
83
- text_encoding = tokenizer(
84
- text,
85
- max_length=512,
86
- padding = "max_length",
87
- truncation = True,
88
- return_attention_mask = True,
89
- add_special_tokens = True,
90
- return_tensors = "pt"
91
- )
92
- generated_ids = model.model.generate(
93
- input_ids = text_encoding["input_ids"],
94
- attention_mask = text_encoding["attention_mask"],
95
- max_length = 150,
96
- num_beams=2,
97
- repetition_penalty = 2.5,
98
- length_penalty=1.0,
99
- early_stopping = True
100
- )
101
-
102
- preds = [
103
- tokenizer.decode(gen_id, skip_special_tokens = True, clean_up_tokenization_spaces=True)
104
- for gen_id in generated_ids
105
- ]
106
-
107
- return "".join(preds)
108
-
109
-
110
- app = Flask(__name__)
111
-
112
- @app.route('/')
113
- def home():
114
- return render_template('home.html')
115
-
116
- @app.route('/predict',methods=['POST'])
117
- def predict():
118
- if request.method == 'POST':
119
- message = request.form['message']
120
- data = [message]
121
- summary = summarize(data)
122
- return render_template('result.html',Summary=summary)
123
-
124
- if __name__ == '__main__':
125
- app.run()