YAML Metadata Warning: empty or missing YAML metadata in the repo card
(https://huggingface.co/docs/hub/model-cards#model-card-metadata)
Uses

You can run the example code below on Colab.

- First, define the `ModernBertForQueryComparison` class:
import torch
from transformers import AutoTokenizer, ModernBertPreTrainedModel, ModernBertModel
from torch import nn
# Model class definition (must be identical to the one used during training).
class ModernBertForQueryComparison(ModernBertPreTrainedModel):
    """Pairwise query-comparison head on top of a shared ModernBERT encoder.

    Subclasses ``ModernBertPreTrainedModel`` so the model can be loaded with
    ``from_pretrained``. The article and sentence inputs are each encoded by
    the same encoder and scored by the same linear head; the model reports the
    difference ``sentence_score - article_score`` as a raw logit.
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): presumably the saved checkpoint's parameter names are
        # prefixed with "bert." because of this attribute name — do not rename
        # without confirming against the checkpoint.
        self.bert = ModernBertModel(config)
        # Use the config's mlp_dropout when it has one, otherwise 0.1.
        self.dropout = nn.Dropout(getattr(config, "mlp_dropout", 0.1))
        # Single shared scorer: hidden vector -> one scalar logit.
        self.score_predictor = nn.Linear(config.hidden_size, 1)
        # Initialize the newly added layers.
        self.post_init()

    def _encode_and_score(self, input_ids, attention_mask):
        """Encode one input and return the scorer's logit for its first-token vector, shape [batch, 1]."""
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # last_hidden_state is [batch, seq_len, hidden_dim]; position 0 is the
        # [CLS] token used as the summary vector.
        cls_vector = outputs.last_hidden_state[:, 0, :]
        cls_vector = self.dropout(cls_vector)
        return self.score_predictor(cls_vector)  # [batch, 1]

    def forward(
        self,
        article_input_ids=None,
        article_attention_mask=None,
        sentence_input_ids=None,
        sentence_attention_mask=None,
        labels=None,
    ):
        """Score both query variants and return their relative preference.

        With ``return_dict=True`` the encoder returns an output object whose
        ``last_hidden_state`` is read at the [CLS] position for scoring.

        Args:
            article_input_ids / article_attention_mask: token ids and mask for
                the article-style query (assumed batch-first tensors — TODO
                confirm with caller).
            sentence_input_ids / sentence_attention_mask: same, for the
                sentence-style query.
            labels: optional tensor of 0/1 targets, one per example;
                1 means the sentence query is better, 0 means the article
                query is better. Integer or float dtypes are accepted.

        Returns:
            dict with ``loss`` (``None`` when ``labels`` is ``None``),
            ``relative_score`` (sentence minus article logit, [batch, 1]),
            ``article_score`` and ``sentence_score`` (each [batch, 1]).
        """
        # Same order as before: article is encoded (and dropout sampled) first.
        article_score = self._encode_and_score(article_input_ids, article_attention_mask)
        sentence_score = self._encode_and_score(sentence_input_ids, sentence_attention_mask)

        # Conceptually sigmoid(sentence) vs sigmoid(article), but the raw
        # difference is kept as a logit because BCEWithLogitsLoss applies the
        # sigmoid itself. > 0 means the sentence query is preferred.
        relative_score = sentence_score - article_score  # [batch, 1]

        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            # BCEWithLogitsLoss requires floating-point targets matching the
            # logits' dtype; integer (Long) 0/1 labels would otherwise raise.
            targets = labels.reshape(-1).to(relative_score.dtype)
            loss = loss_fct(relative_score.view(-1), targets)

        return {
            'loss': loss,
            'relative_score': relative_score,
            'article_score': article_score,
            'sentence_score': sentence_score,
        }
- Then define the `predict_better_query` function:
# Prediction helpers
def _encode_query_pair(tokenizer, query, candidate, device, max_length):
    """Tokenize ``(query, candidate)`` as a text pair and return (input_ids, attention_mask) on ``device``."""
    encoding = tokenizer(
        query,
        candidate,
        add_special_tokens=True,
        max_length=max_length,
        # Pad only to the longest sequence in this call (a single pair, so in
        # practice no padding). 'max_length' padding forced every input up to
        # max_length tokens (8192 by default), making each forward pass far
        # more expensive. NOTE(review): assumes the encoder ignores padded
        # positions via attention_mask, so scores should match up to float
        # rounding — confirm numerically if exact reproducibility matters.
        padding=True,
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    return encoding['input_ids'].to(device), encoding['attention_mask'].to(device)


def predict_better_query(model, tokenizer, query, article_query, sentence_query, device, max_length=8192):
    """Compare two candidate rewrites of ``query`` with the pairwise model.

    Each candidate is tokenized together with ``query`` as a text pair,
    truncated to ``max_length`` tokens, and both pairs are scored in a single
    no-grad forward pass. ``model`` is switched to eval mode as a side effect.

    Args:
        model: a ``ModernBertForQueryComparison``-compatible model returning a
            dict with ``relative_score``, ``article_score``, ``sentence_score``.
        tokenizer: Hugging Face tokenizer matching the model.
        query: the original query string.
        article_query: the article-style candidate string.
        sentence_query: the sentence-style candidate string.
        device: torch device the inputs are moved to.
        max_length: truncation limit in tokens for each (query, candidate) pair.

    Returns:
        dict with ``is_sentence_better`` (True when ``relative_score > 0``),
        and the Python-float ``relative_score``, ``article_score`` and
        ``sentence_score`` for this single example.
    """
    model.eval()

    article_input_ids, article_attention_mask = _encode_query_pair(
        tokenizer, query, article_query, device, max_length
    )
    sentence_input_ids, sentence_attention_mask = _encode_query_pair(
        tokenizer, query, sentence_query, device, max_length
    )

    with torch.no_grad():
        outputs = model(
            article_input_ids=article_input_ids,
            article_attention_mask=article_attention_mask,
            sentence_input_ids=sentence_input_ids,
            sentence_attention_mask=sentence_attention_mask,
            labels=None,
        )

    # .item() works because exactly one example (batch size 1) is scored.
    relative_score = outputs['relative_score'].item()
    article_score = outputs['article_score'].item()
    sentence_score = outputs['sentence_score'].item()

    return {
        # relative_score > 0 => the sentence query is preferred.
        'is_sentence_better': relative_score > 0,
        'relative_score': relative_score,
        'article_score': article_score,
        'sentence_score': sentence_score,
    }
- Finally, run inference:
def main():
    """Load the published checkpoint and print a comparison for a few sample queries."""
    # Repo id / directory passed to from_pretrained for both model and tokenizer.
    model_dir = "CheWei/ModernBERT_16x2_1e-5_8192"

    # Prefer the GPU when one is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load the fine-tuned model and its tokenizer.
    print(f"Loading model from {model_dir}...")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = ModernBertForQueryComparison.from_pretrained(model_dir)
    model.to(device)
    print("Model loaded successfully!")

    # Each entry: (original query, article-style rewrite, sentence-style rewrite).
    samples = [
        (
            "Python programming tutorials",
            "Python programming guides and examples",
            "How to learn Python programming",
        ),
        (
            "Healthy breakfast ideas",
            "Nutritious breakfast recipes for busy mornings",
            "Quick and healthy breakfast options",
        ),
        (
            "史丹佛大學課程",
            "史丹佛大學開設的各類專業課程介紹與選擇指南",
            "史丹佛大學有哪些熱門課程可以選擇",
        ),
    ]

    print("\n=== Running Test Examples ===")
    for number, (query, article_query, sentence_query) in enumerate(samples, start=1):
        print(f"\nExample {number}:")
        print(f"Query: {query}")
        print(f"Article query: {article_query}")
        print(f"Sentence query: {sentence_query}")

        outcome = predict_better_query(
            model, tokenizer, query, article_query, sentence_query, device
        )

        print("--- Results ---")
        print(f"Article query score: {outcome['article_score']:.4f}")
        print(f"Sentence query score: {outcome['sentence_score']:.4f}")
        print(f"Relative score: {outcome['relative_score']:.4f}")
        print(
            "Conclusion: ✓ Sentence query is better"
            if outcome['is_sentence_better']
            else "Conclusion: ✓ Article query is better"
        )
        print("-" * 50)


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
- Downloads last month: 92

Inference Providers (NEW)

This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: the model has no library tag.