Spaces:
Runtime error
Runtime error
File size: 2,243 Bytes
10b912d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from abs_compressor import AbstractCompressor
class KiSCompressor(AbstractCompressor):
    """Prompt compressor backed by a causal language-model checkpoint.

    The default checkpoint is ``philippelaban/keep_it_simple``. The prompt is
    fed to the model followed by the tokenizer's BOS token, and the text the
    model generates after that marker is returned as the compressed prompt.
    """

    def __init__(self, DEVICE: str = 'cpu', model_dir: str = 'philippelaban/keep_it_simple'):
        """Load the tokenizer and model and move the model to ``DEVICE``.

        Args:
            DEVICE: Torch device string used for the model and its inputs
                (for example ``'cpu'``, ``'cuda'`` or ``'cuda:0'``).
            model_dir: Hub id or local path handed to ``from_pretrained``.
        """
        self.DEVICE = DEVICE
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side='right')
        # Pad with the checkpoint's own EOS token instead of a hard-coded
        # literal, so a custom ``model_dir`` keeps working.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.kis_model = AutoModelForCausalLM.from_pretrained(model_dir)
        self.kis_model.to(self.DEVICE)

    def compress(self, original_prompt: str, ratio: float = 0.5, max_length: int = 150, num_beams: int = 4, do_sample: bool = True, num_return_sequences: int = 1, target_index: int = 0) -> dict:
        """Generate a shortened version of ``original_prompt``.

        Args:
            original_prompt: Text to compress.
            ratio: Currently unused by this compressor; kept so the signature
                stays compatible with existing callers.
            max_length: Forwarded to ``generate`` as ``max_length``. In
                transformers this bounds prompt plus generated tokens, so it
                must exceed the tokenized prompt length for anything to be
                generated.
            num_beams: Forwarded to ``generate``.
            do_sample: Forwarded to ``generate``.
            num_return_sequences: Number of candidates requested from
                ``generate``.
            target_index: Index of the candidate returned to the caller.

        Returns:
            A dict with keys ``'compressed_prompt'``, ``'ratio'``
            (compressed / original token count, ``0.0`` when the original
            prompt has no tokens), ``'original_tokens'`` and
            ``'compressed_tokens'``.

        Raises:
            IndexError: If ``target_index`` is not a valid index into the
                generated candidates.
        """
        # NOTE(review): ``self.gpt_tokenizer`` is not assigned in this class;
        # presumably it is provided by AbstractCompressor -- confirm.
        original_tokens = len(self.gpt_tokenizer.encode(original_prompt))

        # The BOS token appended after the prompt marks where the model's own
        # output starts.
        prompt_ids = self.tokenizer.encode(text=original_prompt) + [self.tokenizer.bos_token_id]
        # Build the batch directly on the model's device so every device
        # string (not only the literal 'cuda') matches the model placement.
        input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=self.DEVICE)

        output_ids = self.kis_model.generate(
            input_ids,
            max_length=max_length,
            num_beams=num_beams,
            do_sample=do_sample,
            num_return_sequences=num_return_sequences,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # generate() returns prompt + continuation; keep the continuation only.
        output_ids = output_ids[:, input_ids.shape[1]:]

        eos_token = self.tokenizer.eos_token
        candidates = [
            text.replace(eos_token, "")
            for text in self.tokenizer.batch_decode(output_ids)
        ]
        compressed_prompt = candidates[target_index]
        compressed_tokens = len(self.gpt_tokenizer.encode(compressed_prompt))

        # An empty prompt encodes to zero tokens; report 0.0 rather than
        # raising ZeroDivisionError.
        achieved_ratio = compressed_tokens / original_tokens if original_tokens else 0.0

        return {
            'compressed_prompt': compressed_prompt,
            'ratio': achieved_ratio,
            'original_tokens': original_tokens,
            'compressed_tokens': compressed_tokens,
        }
|