cswamy committed
Commit bd69f73
1 Parent(s): b2e544e

initial commit

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ mt5_amzn_enes_reviews_summarization.pth filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,49 @@
+ import torch
+ import gradio as gr
+
+ from model import create_mt5_small
+
+ # Set up model and tokenizer
+ model, tokenizer = create_mt5_small()
+
+ # Load fine-tuned weights into the model (CPU only)
+ model.load_state_dict(
+     torch.load(
+         f="mt5_amzn_enes_reviews_summarization.pth",
+         map_location=torch.device("cpu")
+     ))
+ model.eval()
+
+ # Predict function
+ def predict(text: str):
+     # Tokenize the input and generate a short summary
+     inputs = tokenizer(text,
+                        max_length=512,
+                        truncation=True,
+                        return_tensors='pt')
+     output_tokens = model.generate(inputs['input_ids'],
+                                    attention_mask=inputs['attention_mask'],
+                                    max_length=30)
+     output_text = tokenizer.batch_decode(output_tokens,
+                                          skip_special_tokens=True)
+
+     return output_text[0]
+
+ # Create examples list
+ examples_list = ["The ball hit the splice a lot and sent a fizzing sensation up the handle and into the bottom hand, so I adapted at each session by playing softer and softer, later and later. I found it very difficult to get down the pitch and meet the ball as it landed and so persuaded myself to play back more. It occurred to me that a better player would manage the shimmy down the pitch with more skill and faster footwork, and that the good sweepers would have to take him on in the way that Kevin Pietersen managed so successfully on occasions.",
+                  "Todo muy bien, cumple con lo esperado. Lo único malo es que: se calienta un poco y la batería no dura 8h. A una persona le ha parecido esto útil"]
+
+ # Create gradio app
+ title = "Summarizer for English and Spanish inputs"
+ description = "mT5-small model fine-tuned for summarization of English or Spanish text on the Amazon reviews dataset."
+
+ demo = gr.Interface(fn=predict,
+                     inputs=gr.Textbox(label="Input",
+                                       placeholder="Enter sentences here in English or Spanish..."),
+                     outputs="text",
+                     examples=examples_list,
+                     title=title,
+                     description=description)
+
+ # Launch gradio
+ demo.launch()
model.py ADDED
@@ -0,0 +1,11 @@
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+ def create_mt5_small():
+     """
+     Initializes the mt5-small model and tokenizer from the Hugging Face Hub.
+     """
+     checkpoint = 'google/mt5-small'
+     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+     model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
+
+     return model, tokenizer
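
For reference, a minimal sketch of exercising create_mt5_small() together with the fine-tuned checkpoint outside the Gradio app; it assumes mt5_amzn_enes_reviews_summarization.pth sits in the working directory, and the review text is made up:

import torch
from model import create_mt5_small

# Build the base model/tokenizer and load the fine-tuned weights on CPU
model, tokenizer = create_mt5_small()
model.load_state_dict(torch.load("mt5_amzn_enes_reviews_summarization.pth",
                                 map_location=torch.device("cpu")))
model.eval()

# Tokenize a single (made-up) review and generate a short summary
inputs = tokenizer("Great value for the price, but the battery barely lasts a day.",
                   max_length=512, truncation=True, return_tensors="pt")
tokens = model.generate(inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        max_length=30)
print(tokenizer.batch_decode(tokens, skip_special_tokens=True)[0])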
mt5_amzn_enes_reviews_summarization.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5cd5e76f4bed56cd7dc3e669fee65d3b8db01af16091b58645b0a56ef86e9449
+ size 1200799301
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch==1.12.0
+ gradio==3.44.1
+ transformers==4.33.1
+ sentencepiece==0.1.99
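
Setup note: with these pinned versions installed (pip install -r requirements.txt), running python app.py should start the Gradio demo locally; the first run downloads the google/mt5-small base weights from the Hugging Face Hub before the fine-tuned checkpoint is loaded.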