# EmailSubject / app.py
import gradio as gr
import wandb
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from peft import PeftModel
import os
import re

def clean_text(text):
    # Lowercase the text
    text = text.lower()
    # Replace non-word characters (punctuation, symbols) with spaces
    text = re.sub(r'\W', ' ', text)
    # Collapse repeated whitespace and trim
    text = re.sub(r'\s+', ' ', text).strip()
    return text
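
# For example, clean_text("Hello, World!  Meeting @ 3pm") returns "hello world meeting 3pm"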

# Log in to Weights & Biases and download the fine-tuned adapter artifact
# (ideally the API key would be injected as a secret rather than hardcoded here)
os.environ["WANDB_API_KEY"] = "d2ad0a7285379c0808ca816971d965fc242d0b5e"
wandb.login()
run = wandb.init(project="Email_subject_gen", job_type="model_loading")
artifact = run.use_artifact('Email_subject_gen/final_model:v0')
artifact_dir = artifact.download()

# The tokenizer comes from the base checkpoint rather than the downloaded artifact
#tokenizer = GPT2Tokenizer.from_pretrained(artifact_dir)
MODEL_KEY = 'olm/olm-gpt2-dec-2022'
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_KEY)
tokenizer.add_special_tokens({'pad_token': '{PAD}'})

model = GPT2LMHeadModel.from_pretrained(MODEL_KEY)
model.resize_token_embeddings(len(tokenizer))  # account for the added pad token
model.config.dropout = 0.1  # Set dropout rate
model.config.attention_dropout = 0.1
# Apply the fine-tuned PEFT adapter from W&B on top of the base model
model = PeftModel.from_pretrained(model, artifact_dir)
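
# Optional: if the runtime exposes a GPU, moving the model there speeds up generation.
# if torch.cuda.is_available():
#     model = model.cuda()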

def generateSubject(email):
    # Wrap the cleaned email body in the prompt format used during fine-tuning
    email = "<email>" + clean_text(email) + "<subject>"
    prompts = [email]
    # Left-pad so the generated tokens follow directly after the prompt
    tokenizer.padding_side = 'left'
    prompts_batch_ids = tokenizer(
        prompts, padding=True, truncation=True, return_tensors='pt').to(model.device)
    output_ids = model.generate(
        **prompts_batch_ids, max_new_tokens=10,
        pad_token_id=tokenizer.pad_token_id)
    # Keep only the text generated after the <subject> marker
    outputs_batch = [seq.split('<subject>')[1] for seq in
                     tokenizer.batch_decode(output_ids, skip_special_tokens=True)]
    tokenizer.padding_side = 'right'
    return outputs_batch[0]
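
# Example usage (illustrative email body; the generated subject depends on the fine-tuned adapter):
# generateSubject("Hi team, the Q3 numbers are attached for review. Let me know your thoughts.")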

# Simple hello-world handler (not wired to the Gradio UI)
def predict(name):
    return "Hello " + name + "!!"

# Serve the subject generator through a minimal Gradio interface
iface = gr.Interface(fn=generateSubject, inputs=gr.Textbox(label="Email body"), outputs="text")
iface.launch()