randomshit11 committed on
Commit
c82ab3f
1 Parent(s): 9268a45

Update app.py

Files changed (1)
  1. app.py +66 -2
app.py CHANGED
@@ -1,3 +1,67 @@
- import gradio as gr
-
- gr.Interface.load("models/randomshit11/fin-bert-1st-shit").launch()
+ import gradio
+ import torch
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
+
+ def chunk_text(text, chunk_size):
+     # Split the input into fixed-size character chunks so each piece fits the model input.
+     chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
+     return chunks
+
+ def shorten_text(text, min_length, max_length):
+     # Summarize each chunk with BART and join the partial summaries.
+     summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
+     chunks = chunk_text(text, 1024)
+     summary_chunks = []
+     for chunk in chunks:
+         summary = summarizer(chunk, max_length=max_length, min_length=min_length, do_sample=False)
+         summary_chunks.append(summary[0]["summary_text"])
+     summary = ' '.join(summary_chunks)
+     return summary
+
+ def paraphrase_text(text, min_length, max_length):
+     tokenizer = AutoTokenizer.from_pretrained("randomshit11/fin-bert-1st-shit")
+     model = AutoModelForSeq2SeqLM.from_pretrained("randomshit11/fin-bert-1st-shit")
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = model.to(device)
+     text_instruction = "Summary: " + text + " </s>"
+     chunks = chunk_text(text_instruction, 1024)
+     output_chunks = []
+     for chunk in chunks:
+         encoding = tokenizer.encode_plus(chunk, padding="longest", return_tensors="pt")
+         input_ids, attention_masks = encoding["input_ids"].to(device), encoding["attention_mask"].to(device)
+         outputs = model.generate(
+             input_ids=input_ids, attention_mask=attention_masks,
+             min_length=min_length,
+             max_length=max_length,
+             do_sample=True,
+             top_k=120,
+             top_p=0.95,
+             early_stopping=True,
+             num_return_sequences=5
+         )
+         # Keep the first of the returned sequences for each chunk.
+         line = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+         output_chunks.append(line)
+     output = ' '.join(output_chunks)
+     return output
+
+ def modify_text(mode, text, min_length, max_length):
+     if mode == "shorten":
+         return shorten_text(text, min_length, max_length)
+     else:
+         return paraphrase_text(text, min_length, max_length)
+
+ gradio_interface = gradio.Interface(
+     fn=modify_text,
+     inputs=[
+         gradio.Radio(["shorten", "Summary"], label="Mode"),
+         "text",
+         gradio.Slider(5, 200, value=30, label="Min length"),
+         gradio.Slider(5, 500, value=130, label="Max length")
+     ],
+     outputs="text",
+     examples=[
+         ["shorten",
+          """Your long input text goes here...""",
+          30, 130]
+     ],
+     title="Text shortener/paraphraser",
+     description="Shortening texts using `facebook/bart-large-cnn`, paraphrasing texts using `fin-bert-1st-shit`.",
+     article="© Tom Söderlund 2022"
+ )
+ gradio_interface.launch()
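
As a rough sanity check of the character-based chunking used above, the helper can be exercised on its own. This is only a local sketch; chunk_text is re-declared here so importing app.py (which calls launch() at import time) is not required:

def chunk_text(text, chunk_size):
    # Same character-based splitting as in app.py above.
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

sample = "word " * 500                 # about 2,500 characters of filler text
pieces = chunk_text(sample, 1024)

print(len(pieces))                     # 3
print([len(p) for p in pieces])        # [1024, 1024, 452]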