GPT_124M / pipeline_gpt.py
samkeet's picture
Upload GPT124MTextGenerationPipeline
2582b26 verified
raw
history blame
1.99 kB
# Importing necessary libraries
import torch
from transformers import Pipeline
class GPT124MTextGenerationPipeline(Pipeline):
    """
    Text-generation pipeline that encodes a prompt with ``self.tokenizer``,
    calls ``self.model.generate`` on the resulting token IDs, and decodes
    the result back to text with ``self.tokenizer``.

    Supported generation options (accepted as keyword arguments): ``max_length``,
    ``do_sample``, ``top_k``, ``top_p``, ``temperature`` and ``device``.
    """

    # Generation settings used when the caller did not supply a value.
    _GENERATION_DEFAULTS = {
        "max_length": 50,
        "do_sample": True,
        "top_k": 50,
        "top_p": 0.95,
        "temperature": 0.9,
        "device": "cpu",
    }

    def _sanitize_parameters(self, **kwargs) -> tuple:
        """
        Organizes input parameters into separate dictionaries for:
        - Preprocessing (encoding)
        - Model forward pass (generation settings)
        - Postprocessing (decoding)

        Only generation settings that were explicitly passed are returned.
        The base ``Pipeline`` calls this hook both at construction time and on
        every call and merges the two results with the call-time values taking
        precedence; returning defaults here would therefore overwrite settings
        chosen at construction time. Defaults are applied in ``_forward``.
        """
        preprocess_kwargs = {}
        forward_kwargs = {
            name: kwargs[name]
            for name in self._GENERATION_DEFAULTS
            if name in kwargs
        }
        postprocess_kwargs = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, prompt_text: str, **preprocess_kwargs) -> dict:
        """
        Encodes input text into token IDs using the tokenizer and converts it
        to a PyTorch tensor.

        Returns a dict with a single key, ``"input_ids"``, holding a
        ``torch.long`` tensor of shape (1, seq_len).

        Raises ``AssertionError`` if ``prompt_text`` is not a non-empty string.
        """
        # NOTE(review): kept as an ``assert`` so the exception type callers see
        # is unchanged; be aware the check is skipped under ``python -O``.
        assert (
            isinstance(prompt_text, str) and len(prompt_text) > 0
        ), "prompt_text must be a non-empty string"
        # Encode the input text
        input_ids = self.tokenizer.encode(prompt_text)
        # Wrap in a list so the tensor carries a batch dimension: (1, seq_len)
        input_tensor = torch.tensor([input_ids], dtype=torch.long)
        return {"input_ids": input_tensor}

    def _forward(self, model_inputs, **forward_kwargs):
        """
        Forwards the tokenized input to the model's generate method.

        Any generation setting missing from ``forward_kwargs`` is filled in
        from ``_GENERATION_DEFAULTS``; explicitly supplied values win.
        """
        generation_kwargs = {**self._GENERATION_DEFAULTS, **forward_kwargs}
        return self.model.generate(**model_inputs, **generation_kwargs)

    def postprocess(self, model_output, **postprocess_kwargs):
        """
        Decodes the model output into human-readable text using the tokenizer.

        ``model_output`` is handed to ``self.tokenizer.decode`` unchanged.
        NOTE(review): this presumes ``self.model.generate`` returns a single
        sequence of token IDs that the tokenizer can decode directly - confirm
        against the model implementation.
        """
        return self.tokenizer.decode(model_output)