|
from SCRL_new.scrl.model import load_model |
|
from transformers import AutoTokenizer |
|
import re |
|
from abs_compressor import AbstractCompressor |
|
|
|
|
|
class SCRLCompressor(AbstractCompressor):
    """Prompt compressor backed by an SCRL sentence-compression model.

    The prompt is cut into consecutive character chunks, each chunk is passed
    to the SCRL model's ``predict``, and the returned pieces are concatenated.
    """

    def __init__(self, model_dir: str, device: str = "cpu", tokenizer_dir: str = "sentence-transformers/paraphrase-distilroberta-base-v2"):
        """Load the SCRL model and the tokenizer it uses.

        Args:
            model_dir: Path handed to ``load_model`` to load the SCRL model.
            device: Device string forwarded to ``load_model`` and ``predict``.
            tokenizer_dir: Name or path handed to ``AutoTokenizer.from_pretrained``.
        """
        self.model_dir = model_dir
        self.device = device
        self.model = load_model(self.model_dir, self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

    def _count_tokens(self, text: str) -> int:
        """Return the token count of *text* used for the ratio statistics."""
        # NOTE(review): ``gpt_tokenizer`` is never assigned in this class; it is
        # presumably expected to come from AbstractCompressor -- confirm. When it
        # is absent, fall back to the SCRL tokenizer (without special tokens)
        # instead of failing with AttributeError.
        gpt_tokenizer = getattr(self, "gpt_tokenizer", None)
        if gpt_tokenizer is not None:
            return len(gpt_tokenizer.encode(text))
        return len(self.tokenizer.encode(text, add_special_tokens=False))

    @staticmethod
    def _split_into_chunks(text: str, max_length: int) -> list:
        """Split *text* into consecutive pieces of at most *max_length* characters.

        Every character (including newlines and a short final remainder) ends
        up in exactly one chunk, so no part of the prompt is discarded.
        """
        return [text[start:start + max_length] for start in range(0, len(text), max_length)]

    def compress(self, original_prompt: str, ratio: float = 0.5, max_length: int = 256) -> dict:
        """Compress *original_prompt* chunk by chunk with the SCRL model.

        Args:
            original_prompt: Text to compress; surrounding whitespace is stripped
                before chunking.
            ratio: Not used by this compressor -- the amount kept is decided by
                the SCRL model. Presumably accepted to match the
                AbstractCompressor interface.
            max_length: Maximum number of characters per chunk sent to the model.

        Returns:
            A dict with ``compressed_prompt`` (str), ``ratio`` (compressed /
            original token count, 0 when the original has no tokens),
            ``original_tokens`` (int) and ``compressed_tokens`` (int).

        Raises:
            ValueError: If ``max_length`` is smaller than 1.
        """
        if max_length < 1:
            raise ValueError(f"max_length must be a positive integer, got {max_length}")

        original_tokens = self._count_tokens(original_prompt)
        sources = self._split_into_chunks(original_prompt.strip(), max_length)

        if not sources:
            # Nothing but whitespace: nothing to send to the model.
            return {
                'compressed_prompt': "",
                'ratio': 0,
                'original_tokens': original_tokens,
                'compressed_tokens': 0,
            }

        summaries = self.model.predict(sources, self.tokenizer, self.device)
        # Chunks are contiguous slices of the prompt, so their summaries are
        # concatenated back in order without a separator.
        compressed_prompt = "".join(summaries)
        compressed_tokens = self._count_tokens(compressed_prompt)

        return {
            'compressed_prompt': compressed_prompt,
            'ratio': compressed_tokens / original_tokens if original_tokens else 0,
            'original_tokens': original_tokens,
            'compressed_tokens': compressed_tokens,
        }
|
|
|
|