import json
import time

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


# Dataset that tokenizes each document and splits the padded token sequence
# into fixed-length chunks for hierarchical summarization.
class Dataset4Summarization(Dataset):
    def __init__(self, data, tokenizer, max_length=1024 * 3, chunk_length=1024):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.chunk_length = chunk_length

    def __len__(self):
        return len(self.data)

    def chunking(self, tensor):
        # The tokenizer pads every sample to exactly max_length, so slicing
        # in chunk_length steps always yields equally sized chunks.
        chunks = []
        for i in range(0, self.max_length, self.chunk_length):
            chunks.append(tensor[i:i + self.chunk_length])
        return chunks

    def __getitem__(self, idx):
        sample = self.data[idx]
        inputs = self.tokenizer(sample, return_tensors='pt', padding='max_length',
                                truncation=True, max_length=self.max_length)

        list_chunk = self.chunking(inputs['input_ids'].squeeze())
        list_attention_mask = self.chunking(inputs['attention_mask'].squeeze())

        return {
            'list_input_ids': list_chunk,
            'list_att_mask': list_attention_mask,
        }
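
# Illustration of the defaults above (assumed values, matching __init__):
# with max_length = 3 * 1024 and chunk_length = 1024, a padded input_ids
# tensor of length 3072 is split into three (1024,) chunks; at batch_size=1
# the DataLoader collates each chunk to shape (1, 1024), which is exactly
# what model.generate receives in infer_2_hier below.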
	

# Concatenate the raw text of all single documents in one JSON record.
def process_data_infer(data):
    single_documents = data.get('single_documents', [])

    result = []
    for doc in single_documents:
        raw_text = doc.get('raw_text', '')
        result.append(raw_text)

    return " ".join(result)


# Read a JSONL file and build one concatenated document string per line.
def processing_data_infer(input_file):
    all_results = []

    with open(input_file, 'r', encoding='utf-8') as file:
        for line in file:
            data = json.loads(line.strip())
            result = process_data_infer(data)
            all_results.append(result)

    return all_results
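
# For reference, each line of the input file is assumed to be a JSON record
# shaped like the keys accessed above (illustrative structure, not real data):
#
#   {"single_documents": [{"raw_text": "..."}, {"raw_text": "..."}]}
#
# processing_data_infer then returns one concatenated string per JSONL line.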

# Load the base model and tokenizer, then restore the fine-tuned weights.
tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base-vietnews-summarization")
model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-base-vietnews-summarization")

device = torch.device('cpu')
model.to(device)

model.load_state_dict(torch.load("./weight_cp19_model.pth", map_location=device))
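
# Note: inference is pinned to CPU above. To use a GPU when one is available
# (a common variant, not part of the original setup), you could instead write:
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")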

# For other demo purposes, just make sure `data` is a list of documents:
# [document1, document2].

# batch_size must be 1: each document is chunked and summarized on its own.
@torch.no_grad()
def infer_2_hier(model, data_loader, device, tokenizer):
    model.eval()
    start = time.time()
    all_summaries = []
    for batch in data_loader:
        summaries = []
        inputs = batch['list_input_ids']
        att_mask = batch['list_att_mask']

        for i in range(len(inputs)):
            print(f"input {i}")
            # Skip chunks that are pure padding (all token ids equal 0,
            # the T5 pad token id).
            if torch.all(inputs[i] == 0):
                continue
            summary = model.generate(inputs[i].to(device),
                                     attention_mask=att_mask[i].to(device),
                                     max_length=128,
                                     num_beams=4,
                                     num_return_sequences=1,
                                     no_repeat_ngram_size=3)
            summaries.append(summary)

        if not summaries:
            # Empty document (every chunk was padding): keep one output per
            # input so results stay aligned with the dataset.
            all_summaries.append("")
            continue

        # Concatenate the chunk summaries along the sequence dimension and
        # decode them as a single text.
        summaries = torch.cat(summaries, dim=1)
        all_summaries.append(tokenizer.decode(summaries.squeeze(), skip_special_tokens=True))

    end = time.time()
    print(f"Time: {end - start}")
    return all_summaries

def vit5_infer(data):
    dataset = Dataset4Summarization(data, tokenizer)
    data_loader = DataLoader(dataset, batch_size=1)
    result = infer_2_hier(model, data_loader, device, tokenizer)
    return result
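

# Minimal usage sketch (the file name "input.jsonl" is an assumed placeholder;
# any JSONL file in the format shown above works):
if __name__ == "__main__":
    documents = processing_data_infer("input.jsonl")  # one string per line
    summaries = vit5_infer(documents)                 # one summary per document
    for s in summaries:
        print(s)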