File size: 2,537 Bytes
c82ab3f
 
 
9268a45
c82ab3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

def chunk_text(text, chunk_size):
    """Split *text* into consecutive pieces of at most *chunk_size* characters.

    The final piece may be shorter; an empty string yields an empty list.
    """
    pieces = []
    start = 0
    while start < len(text):
        pieces.append(text[start:start + chunk_size])
        start += chunk_size
    return pieces

def shorten_text(text, min_length, max_length):
    """Summarize *text* with facebook/bart-large-cnn, chunk by chunk.

    The text is split into 1024-character chunks (a rough proxy for the
    model's input limit), each chunk is summarized independently, and the
    partial summaries are joined with spaces.

    :param text: input text to shorten.
    :param min_length: minimum length of each chunk's summary (tokens).
    :param max_length: maximum length of each chunk's summary (tokens).
    :return: concatenated summary string.
    """
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
    summary_chunks = []
    for chunk in chunk_text(text, 1024):
        # BUG FIX: the length bounds must be keyword arguments. Passed
        # positionally, the pipeline treats them as extra inputs rather
        # than as generation limits.
        result = summarizer(chunk, max_length=max_length, min_length=min_length, do_sample=False)
        summary_chunks.append(result[0]["summary_text"])
    return ' '.join(summary_chunks)

def paraphrase_text(text, min_length, max_length):
    """Paraphrase *text* with the fin-bert seq2seq model, chunk by chunk.

    The text is prefixed with the model's "Summary: " instruction, split
    into 1024-character chunks, and each chunk is regenerated by sampling.

    :param text: input text to paraphrase.
    :param min_length: minimum generation length (tokens) per chunk.
    :param max_length: maximum generation length (tokens) per chunk.
    :return: concatenated paraphrased string.
    """
    tokenizer = AutoTokenizer.from_pretrained("randomshit11/fin-bert-1st-shit")
    model = AutoModelForSeq2SeqLM.from_pretrained("randomshit11/fin-bert-1st-shit")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # BUG FIX: the model must live on the same device as the inputs;
    # previously only the tensors were moved, which crashes on CUDA.
    model = model.to(device)
    text_instruction = "Summary: " + text + " </s>"
    output_chunks = []
    for chunk in chunk_text(text_instruction, 1024):
        encoding = tokenizer.encode_plus(chunk, padding="longest", return_tensors="pt")
        input_ids = encoding["input_ids"].to(device)
        attention_masks = encoding["attention_mask"].to(device)
        outputs = model.generate(
            input_ids=input_ids, attention_mask=attention_masks,
            # BUG FIX: min_length was accepted by this function but never
            # forwarded to generate(); honor the caller's minimum.
            min_length=min_length,
            max_length=max_length,
            do_sample=True,
            top_k=120,
            top_p=0.95,
            early_stopping=True,
            # Only the first sequence was ever used, so generating five
            # (num_return_sequences=5) was pure wasted compute.
            num_return_sequences=1,
        )
        line = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
        output_chunks.append(line)
    return ' '.join(output_chunks)

def modify_text(mode, text, min_length, max_length):
    """Dispatch to the summarizer or paraphraser.

    Any *mode* other than "shorten" selects paraphrasing.
    """
    handler = shorten_text if mode == "shorten" else paraphrase_text
    return handler(text, min_length, max_length)

# Build and launch the web UI. modify_text treats any mode other than
# "shorten" as paraphrase mode, so the second Radio choice is labeled
# "paraphrase" (it was previously mislabeled "Summary").
gradio_interface = gradio.Interface(
    fn=modify_text,
    inputs=[
        gradio.Radio(["shorten", "paraphrase"], label="Mode"),
        "text",
        gradio.Slider(5, 200, value=30, label="Min length"),
        gradio.Slider(5, 500, value=130, label="Max length")
    ],
    outputs="text",
    examples=[
        ["shorten",
         """Your long input text goes here...""",
         30, 130]
    ],
    title="Text shortener/paraphraser",
    description="Shortening texts using `facebook/bart-large-cnn`, paraphrasing texts using `fin-bert-1st-shit`.",
)
gradio_interface.launch()