EmailSubject / app.py
adarshj322's picture
Changes to length
04846e4
raw
history blame
1.9 kB
import gradio
import wandb
import torch
from transformers import GPT2Tokenizer,GPT2LMHeadModel
from peft import PeftModel
import os
import re
def clean_text(text):
    """Normalize raw email text before it is fed to the model.

    The text is lowercased, every non-word character (anything ``\\W``
    matches, including punctuation and whitespace) becomes a space, and runs
    of spaces are squeezed to one with the ends trimmed.
    """
    lowered = text.lower()
    spaced = re.sub(r'\W', ' ', lowered)
    # After the substitution the only whitespace left is ' ', so a plain
    # split/join both collapses runs and strips the ends.
    return ' '.join(spaced.split())
os.environ["WANDB_API_KEY"] = "d2ad0a7285379c0808ca816971d965fc242d0b5e"
wandb.login()
run = wandb.init(project="Email_subject_gen", job_type="model_loading")
artifact = run.use_artifact('Email_subject_gen/final_model:v0')
artifact_dir = artifact.download()
#tokenizer= GPT2Tokenizer.from_pretrained(artifact_dir)
MODEL_KEY = 'olm/olm-gpt2-dec-2022'
tokenizer= GPT2Tokenizer.from_pretrained(MODEL_KEY)
tokenizer.add_special_tokens({'pad_token':'{PAD}'})
model = GPT2LMHeadModel.from_pretrained(MODEL_KEY)
model.resize_token_embeddings(len(tokenizer))
model.config.dropout = 0.1 # Set dropout rate
model.config.attention_dropout = 0.1
model = PeftModel.from_pretrained(model, artifact_dir)
def generateSubject(email, max_new_tokens=10):
    """Generate a subject line for an email with the fine-tuned model.

    The email is normalized with ``clean_text`` and wrapped in the prompt
    format ``"<email>" + body + "<subject>"``. The model's continuation after
    the prompt is returned as the subject.

    Args:
        email: Raw email text. ``None`` is treated as an empty email.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The generated subject text, with surrounding whitespace stripped.
    """
    prompt = "<email>" + clean_text(email or "") + "<subject>"
    prompt_ids = tokenizer(prompt)["input_ids"]

    # GPT-2 has a fixed context window, and prompt plus generated tokens must
    # fit inside it. Plain right-truncation would cut off the trailing
    # "<subject>" marker and leave no room to generate. Instead, drop tokens
    # from the end of the email body and re-append the marker.
    context_size = getattr(model.config, "n_positions", None) or tokenizer.model_max_length
    budget = context_size - max_new_tokens
    if len(prompt_ids) > budget:
        suffix_ids = tokenizer("<subject>")["input_ids"]
        prompt_ids = prompt_ids[:max(budget - len(suffix_ids), 0)] + suffix_ids

    # A single prompt needs no padding, so the shared tokenizer's padding
    # side is left untouched.
    input_ids = torch.tensor([prompt_ids], device=model.device)
    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
    )

    # generate() echoes the prompt; decode only the newly generated tokens.
    generated = tokenizer.decode(
        output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
    )
    # The model may run on into another prompt marker; keep only the text
    # before it (the original also cut at a repeated "<subject>").
    subject = generated.split("<subject>")[0].split("<email>")[0]
    return subject.strip()
def predict(name):
    """Return a greeting for *name*.

    Leftover demo function; it is not referenced elsewhere in this file.
    """
    greeting = "Hello " + name
    return greeting + "!!"
# Expose the subject generator as a text-in / text-out web UI and start serving.
iface = gradio.Interface(generateSubject, "text", "text")
iface.launch()