import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


# Load the model architecture from COLD and the fine-tuned parameters.
model_name = "thu-coai/roberta-base-cold"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
model_path = "finetuned_cold_LoL.pth"  # Can be downloaded from this repo.
model.load_state_dict(torch.load(model_path, map_location="cpu"))  # map_location keeps loading safe on CPU-only machines
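# Not in the original snippet: switching to evaluation mode disables dropout,
# so the demo below gives deterministic predictions (standard practice for inference).
model.eval()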


# Demo for toxicity detection
# Example inputs: Chinese texts from a League of Legends gaming context.
texts = ['狠狠地导', '卡了哟', 'gala有卡莎皮肤,你们这些小黑子有吗?', '早改了,改成回血了']
model_input = tokenizer(texts, return_tensors="pt", padding=True)
model_output = model(**model_input, return_dict=False)
prediction = torch.argmax(model_output[0].cpu(), dim=-1)
prediction = [p.item() for p in prediction]
# prediction = [1, 0, 1, 0] # 1 for toxic, 0 for non-toxic
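
# If a confidence score is needed rather than a hard label, the logits can be converted to
# probabilities with a softmax. The helper below is a minimal sketch, not part of the original
# card: it assumes the `model` and `tokenizer` loaded above and wraps inference in
# torch.no_grad() so no gradient graph is built.
def predict_toxicity(texts):
    # Hypothetical helper: returns the probability of the toxic class (label 1) per input text.
    with torch.no_grad():
        inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)
    return probs[:, 1].tolist()  # column 1 is the toxic class

# Example usage:
# print(predict_toxicity(['卡了哟']))  # expected below 0.5, consistent with the label 0 above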