# lionguard-sexual-v1.0 / inference.py
# Last commit: "fix: update code to latest" (ac4b19e, shaunkhoo).
# File size at capture: 2.67 kB.
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
import sys
import json
import onnxruntime as rt
# Hub repository that hosts both the JSON configuration and the classifier file.
repo_path = "govtech/lionguard-sexual-v1.0"

# Fetch config.json at import time; every function below reads its settings
# (tokenizer, embedding model, batch size, threshold, ...) from `config`.
config_path = hf_hub_download(repo_id=repo_path, filename="config.json")

# JSON is UTF-8 by specification, so do not rely on the platform's locale codec.
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)
def get_embeddings(device, data):
    """Embed a batch of texts with the configured embedding model.

    The texts are tokenized (padded, and truncated to ``config['max_length']``),
    pushed through the model in mini-batches of ``config['batch_size']``, and
    each text is represented by the L2-normalized vector found at position 0
    of the model's first output (presumably the [CLS] hidden state -- confirm
    against the embedding model's documentation).

    Args:
        device: Torch device (or device string such as ``"cpu"``) on which the
            model and the tokenized inputs are placed.
        data: Sequence of input strings. A single ``str`` is treated as a
            one-element batch instead of being sliced character by character.

    Returns:
        ``np.ndarray`` with one row per input text. For empty ``data`` an
        empty array is returned without loading the tokenizer or the model.
    """
    # Slicing a bare string below would yield character chunks, not sentences.
    if isinstance(data, str):
        data = [data]

    # Nothing to embed: skip the expensive tokenizer/model load entirely.
    if len(data) == 0:
        return np.array([])

    tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
    model = AutoModel.from_pretrained(config['embedding_model'])
    model.eval()
    model.to(device)

    batch_size = config['batch_size']
    max_length = config['max_length']

    output = []
    # Step through the data in strides of batch_size (integer arithmetic only).
    for start in range(0, len(data), batch_size):
        sentences = data[start:start + batch_size]
        encoded_input = tokenizer(
            sentences,
            max_length=max_length,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        # Keep the returned object so the device move is explicit.
        encoded_input = encoded_input.to(device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            model_output = model(**encoded_input)
            sentence_embeddings = model_output[0][:, 0]

        # Scale every embedding to unit L2 norm.
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        output.extend(sentence_embeddings.cpu().numpy())

    return np.array(output)
def predict(batch_text):
    """Score a batch of texts and apply the recommended threshold.

    The texts are embedded with ``get_embeddings``, the classifier named by
    ``config['model_name']`` is downloaded from the hub repository and run
    through ONNX Runtime, and every score is compared with
    ``config['threshold']``.

    Args:
        batch_text: Sequence of input strings.

    Returns:
        dict with two entries:
            ``'scores'``: one score per text (a list when
                ``config['calibrated']`` is truthy, otherwise a flattened
                array taken from the second model output).
            ``'predictions'``: list of ``1`` (score >= threshold) or ``0``.
        For an empty batch both entries are empty lists and no model is
        downloaded or executed.
    """
    # An empty batch would otherwise reach the ONNX session with a (0, 0) input.
    if len(batch_text) == 0:
        return {'scores': [], 'predictions': []}

    # Use a torch.device on both branches rather than mixing it with a str.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    embeddings = get_embeddings(device, batch_text)

    # Load the model
    model_fp = hf_hub_download(repo_id=repo_path, filename=config['model_name'])
    session = rt.InferenceSession(model_fp)

    # Prepare input data: the session is fed float32 features directly
    # (no DataFrame round trip is needed for a 2-D embedding array).
    input_name = session.get_inputs()[0].name
    X_input = np.asarray(embeddings, dtype=np.float32)

    # Run inference
    outputs = session.run(None, {input_name: X_input})

    # If calibrated, return only the prediction for the unsafe class
    if config['calibrated']:
        scores = [output[1][1] for output in outputs[1]]
    # If not calibrated, we will only get a 1D array for the unsafe class
    else:
        scores = outputs[1].flatten()

    # Generate the predictions depending on the recommended threshold score
    threshold = config['threshold']
    predictions = [1 if score >= threshold else 0 for score in scores]

    return {
        'scores': scores,
        'predictions': predictions
    }
def parse_batch_text(raw):
    """Decode the command-line JSON payload into a list of texts.

    Args:
        raw: JSON document given on the command line. Either an array of
            strings or a single string (treated as a one-element batch).

    Returns:
        list of ``str``.

    Raises:
        ValueError: if ``raw`` is not valid JSON (``json.JSONDecodeError`` is
            a ``ValueError`` subclass) or does not decode to strings.
    """
    parsed = json.loads(raw)
    # A lone JSON string would otherwise be handled character by character.
    if isinstance(parsed, str):
        parsed = [parsed]
    if not isinstance(parsed, list) or not all(isinstance(text, str) for text in parsed):
        raise ValueError("Input must be a JSON array of strings")
    return parsed


if __name__ == "__main__":
    # Fail with a usage hint instead of a bare IndexError when no input is given.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: python {sys.argv[0]} '[\"first text\", \"second text\"]'")

    # Load the data
    batch_text = parse_batch_text(sys.argv[1])

    # Generate the scores and predictions
    outputs = predict(batch_text)

    # Report one line per input text, numbered from 1.
    for i, (score, prediction) in enumerate(zip(outputs['scores'], outputs['predictions']), start=1):
        print(f"[Text {i}] Score: {score:.3f}, Prediction: {prediction}")