File size: 2,248 Bytes
a528501
 
8817b20
 
 
 
 
 
 
 
 
 
a528501
 
 
4b98af4
a528501
4b98af4
a528501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8817b20
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Third-party imports consolidated at the top per PEP 8 (the transformers
# import previously appeared mid-file, after executable statements).
import streamlit as st
import torch  # kept: transformers' PyTorch backend expects torch to be importable
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# --- Page layout -------------------------------------------------------
st.title("Title Generation with Transformers")
st.write("")
st.write("Input your text here!")


# Sample news article used to pre-populate the input box.
default_value = "Ukrainian counterattacks: Kharkiv's regional administrator said a number of villages around Malaya Rogan were retaken by Ukrainian forces. Video verified by CNN shows Ukrainian troops in control of Vilkhivka, one of the settlements roughly 20 miles from the Russian border. The success of Ukrainian forces around Kharkiv has been mirrored further north, near the city of Sumy, where Ukrainian troops have liberated a number of settlements, according to videos geolocated and verified by CNN. A separate counterattack in the south also led to the liberation of two villages from Russian forces northwest of Mariupol, according to the Zaporizhzhia regional military administration."

# User-editable input text; defaults to the sample article above.
sent = st.text_area("Text", default_value, height=50)

# Load the pretrained seq2seq title-generation model and its tokenizer.
# NOTE(review): these run on every Streamlit rerun, re-loading the model
# each time the user edits the text — consider st.cache_resource if the
# installed Streamlit version supports it.
tokenizer = AutoTokenizer.from_pretrained("deep-learning-analytics/automatic-title-generation")
model = AutoModelForSeq2SeqLM.from_pretrained("deep-learning-analytics/automatic-title-generation")

def tokenize_data(text):
    """Encode *text* for the title-generation model.

    Returns a dict with ``input_ids`` and ``attention_mask`` PyTorch
    tensors, padded/truncated to a fixed length of 120 tokens.
    """
    # Append an explicit end-of-sequence marker before tokenizing.
    source = str(text) + ' </s>'
    encoded = tokenizer(
        source,
        padding='max_length',
        truncation=True,
        max_length=120,
        return_attention_mask=True,
        return_tensors='pt',
    )
    return {
        "input_ids": encoded['input_ids'],
        "attention_mask": encoded['attention_mask'],
    }

def generate_answers(text):
    """Generate one candidate title for *text* and return it as a string.

    Uses top-k/top-p sampling, so repeated calls on the same input may
    return different titles.
    """
    encoded = tokenize_data(text)
    generated = model.generate(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        do_sample=True,
        max_length=120,
        top_k=120,
        top_p=0.98,
        early_stopping=True,
        num_return_sequences=1,
    )
    # Decode only the first (and only) returned sequence.
    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Generate a title for the current text-area contents and display it.
# NOTE(review): runs on every Streamlit rerun, so each edit triggers a
# fresh (sampled, non-deterministic) generation.
answer = generate_answers(sent)

st.write(answer)

# Earlier Gradio-based UI, superseded by the Streamlit code above.
#iface = gr.Interface(fn=generate_answers,inputs=[gr.inputs.Textbox(lines=20)], outputs=["text"])
#iface.launch(inline=False, share=True)