email_generator.py
""" Functions for loading email generator."""
import torch
import transformers
# Run on the first GPU when CUDA is available; -1 tells the pipeline to use the CPU.
DEVICE = 0 if torch.cuda.is_available() else -1
AVAILABLE_GENERATORS = [
    'pszemraj/opt-350m-email-generation',
    'postbot/gpt2-medium-emailgen',
    'sagorsarker/emailgenerator'
]
DEFAULT_GENERATOR = 'pszemraj/opt-350m-email-generation'
class EmailGenerator:
""" Class that loads and wraps a HuggingFace email generation pipeline."""
def __init__(self, model_tag: str) -> None:
""" Initialize HuggingFace email generation pipeline.
Args:
model_tag (str): Model name.
"""
self.tag = model_tag
        # Greedy decoding (do_sample=False) keeps the generated emails deterministic.
        self.generator = transformers.pipeline(
            'text-generation',
            model=model_tag,
            use_fast=True,
            do_sample=False,
            device=DEVICE
        )
def generate(self, prompt: str, max_tokens: int) -> str:
""" Generate a sample from a given prompt.
Args:
prompt (str): Prompting for email generator.
            max_tokens (int): Maximum total sequence length in tokens (prompt included); passed to the pipeline as max_length.
Returns:
str: Generated text.
"""
output = self.generator(prompt, max_length=max_tokens)
return output[0]['generated_text']
def __str__(self):
return f'EmailGenerator({self.tag})'
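
# A minimal usage sketch for EmailGenerator, kept as a comment so nothing runs at
# import time. DEFAULT_GENERATOR comes from this module; the prompt text and the
# 64-token budget are illustrative assumptions, and the model weights are fetched
# from the Hugging Face Hub on first use.
#
#     gen = EmailGenerator(DEFAULT_GENERATOR)
#     print(gen)  # EmailGenerator(pszemraj/opt-350m-email-generation)
#     print(gen.generate('Hi team,\n\nJust a quick update on', max_tokens=64))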
def set_global_generator(model_tag: str = DEFAULT_GENERATOR) -> None:
    """ Set the global variable 'generator' to an EmailGenerator built from model_tag."""
global generator
generator = EmailGenerator(model_tag=model_tag)
def generator_exists() -> bool:
    """ Check whether the global variable 'generator' has been defined."""
return 'generator' in globals()
def generate_email(model_tag: str, prompt: str, max_tokens: int) -> str:
    """ Generate an email from a prompt using the global generator.

    If no global generator exists, or the existing one was built from a
    different model tag, the correct generator is initialized first.
    """
if not generator_exists() or generator.tag != model_tag:
set_global_generator(model_tag=model_tag)
return generator.generate(prompt, max_tokens=max_tokens)
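
# Hedged demo entry point (not part of the original API): a sketch of the lazy
# global-generator flow end to end. The prompt text and max_tokens value are
# arbitrary assumptions; the model tag is this module's DEFAULT_GENERATOR.
if __name__ == '__main__':
    sample = generate_email(
        model_tag=DEFAULT_GENERATOR,
        prompt='Dear Dr. Lee,\n\nThank you for taking the time to',
        max_tokens=96,
    )
    print(sample)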