import re
import pickle

import torch
import torch.nn as nn
# Define paths
MODEL_PATH = "spam_model.pth"
VOCAB_PATH = "vocab.pkl"
class TransformerEncoder(nn.Module):
    def __init__(self, d_model=256, num_heads=1, d_ff=512, num_layers=1,
                 vocab_size=10000, max_seq_len=100, dropout=0.1):
        super().__init__()
        # Embedding & learned positional encoding
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_len, d_model))
        # Transformer encoder layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            activation='relu',
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Classification head: single logit -> probability via sigmoid
        self.fc = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Add positional encodings for the actual sequence length
        x = self.embedding(x) + self.positional_encoding[:, :x.size(1), :]
        x = self.encoder(x)   # Pass through the transformer encoder
        x = x[:, 0, :]        # Take the first token's output (CLS-token equivalent)
        x = self.fc(x)
        return self.sigmoid(x)  # Probability of spam (binary classification)
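
# Shape sketch (illustrative, not from the original file): with the defaults
# above, a batch of token-ID sequences maps to one spam probability each.
#   dummy = torch.randint(0, 10000, (2, 100))  # (batch=2, seq_len=100)
#   probs = TransformerEncoder()(dummy)        # shape (2, 1), values in (0, 1)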
# Load vocabulary
with open(VOCAB_PATH, "rb") as f:
    vocab = pickle.load(f)
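# `vocab` is expected to be a plain dict mapping token -> integer ID, including
# the special tokens '<PAD>' and '<UNK>' relied on by predict() below.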
# Load model
device = torch.device("cpu") # Change to "cuda" if using GPU
model = TransformerEncoder(d_model=256, num_heads=1, num_layers=1, vocab_size=len(vocab), max_seq_len=100).to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval() # Set model to evaluation mode
print("✅ Model and vocabulary loaded successfully!")
def simple_tokenize(text):
    # Lowercase, then split on word boundaries (letters, digits, underscore)
    return re.findall(r"\b\w+\b", text.lower())
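
# For example, simple_tokenize("Win £1000 now!") yields ['win', '1000', 'now']:
# punctuation and currency symbols fall outside \w and are dropped.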
def predict(text, model, vocab, max_len=100):
    # max_len must not exceed the model's max_seq_len (100), or the
    # positional-encoding addition in forward() will fail.
    model.eval()
    tokens = simple_tokenize(text)  # already lowercases internally
    token_ids = [vocab.get(word, vocab['<UNK>']) for word in tokens]
    token_ids = token_ids[:max_len]                             # Truncate if too long
    token_ids += [vocab['<PAD>']] * (max_len - len(token_ids))  # Pad if too short
    input_tensor = torch.tensor([token_ids], dtype=torch.long).to(device)
    with torch.no_grad():
        output = model(input_tensor).squeeze().item()
    return "Spam" if output > 0.5 else "Ham"
# Test prediction
sample_text = "FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv"
print(f"Prediction: {predict(sample_text, model, vocab)}")