File size: 4,392 Bytes
5f8cc30 99d5bee 5f8cc30 99d5bee 5f8cc30 99d5bee 5f8cc30 99d5bee 5f8cc30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "2a12a2b3",
"metadata": {},
"outputs": [],
"source": [
"from safetensors import safe_open\n",
"import torch\n",
"from torch.nn import functional as F\n",
"from transformers import AutoModel, AutoTokenizer"
]
},
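{
"cell_type": "code",
"execution_count": null,
"id": "148ce181",
"metadata": {},
"outputs": [],
"source": [
"# First clone the model repository locally. This creates ./mdbr-leaf-ir, which the next cell loads from.\n",
"# Run this cell only once: git refuses to clone into an existing, non-empty directory.\n",
"!git clone https://huggingface.co/MongoDB/mdbr-leaf-ir"
]
},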
{
"cell_type": "code",
"execution_count": 2,
"id": "ba9ec6c7",
"metadata": {},
"outputs": [],
"source": [
"# Load the encoder and its tokenizer from the local clone\n",
"MODEL = \"mdbr-leaf-ir\"\n",
"\n",
"# add_pooling_layer=False: the pooler is not needed, because pooling over\n",
"# last_hidden_state is done by hand further down in this notebook.\n",
"model = AutoModel.from_pretrained(MODEL, add_pooling_layer=False)\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ebaf1a76",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Similarities:\n",
"tensor([[0.6857, 0.4598],\n",
" [0.4238, 0.5723]])\n"
]
}
],
"source": [
"# Load the dense projection head stored next to the encoder weights.\n",
"tensors = {}\n",
"with safe_open(MODEL + \"/2_Dense/model.safetensors\", framework=\"pt\") as f:\n",
"    for k in f.keys():\n",
"        tensors[k] = f.get_tensor(k)\n",
"\n",
"# Size the layer from the checkpoint instead of hard-coding 384 -> 768:\n",
"# torch.nn.Linear stores its weight as (out_features, in_features).\n",
"out_features, in_features = tensors[\"linear.weight\"].shape\n",
"W_out = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=True)\n",
"W_out.load_state_dict({\n",
"    \"weight\": tensors[\"linear.weight\"],\n",
"    \"bias\": tensors[\"linear.bias\"],\n",
"})\n",
"\n",
"# Inference only: put both modules in eval mode. No `_ =` needed here, since\n",
"# only the last expression of a cell is displayed.\n",
"model.eval()\n",
"W_out.eval()\n",
"\n",
"\n",
"def embed(texts):\n",
"    \"\"\"Embed a list of strings into L2-normalized vectors.\n",
"\n",
"    Steps: tokenize -> encoder -> masked mean pooling -> dense projection\n",
"    -> L2 normalization. Reads the notebook-level `tokenizer`, `model`\n",
"    and `W_out`.\n",
"\n",
"    Args:\n",
"        texts: list of strings to embed.\n",
"\n",
"    Returns:\n",
"        Tensor with one unit-norm row per input string.\n",
"    \"\"\"\n",
"    tokens = tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=512)\n",
"    with torch.inference_mode():\n",
"        hidden = model(**tokens).last_hidden_state\n",
"        # Zero out padded positions so they do not contribute to the mean.\n",
"        mask = tokens.attention_mask.unsqueeze(-1)\n",
"        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)\n",
"        # Map to the desired output dimension, then normalize.\n",
"        return F.normalize(W_out(pooled), dim=-1)\n",
"\n",
"\n",
"# Example queries and documents\n",
"queries = [\n",
"    \"What is machine learning?\",\n",
"    \"How does neural network training work?\"\n",
"]\n",
"\n",
"documents = [\n",
"    \"Machine learning is a subset of artificial intelligence that focuses on algorithms that can learn from data.\",\n",
"    \"Neural networks are trained through backpropagation, adjusting weights to minimize prediction errors.\"\n",
"]\n",
"\n",
"# Only the queries get the instruction prefix; documents are embedded as-is.\n",
"QUERY_PREFIX = 'Represent this sentence for searching relevant passages: '\n",
"queries_with_prefix = [QUERY_PREFIX + query for query in queries]\n",
"\n",
"query_embeddings = embed(queries_with_prefix)\n",
"document_embeddings = embed(documents)\n",
"\n",
"# Rows are unit-norm, so the dot product is the cosine similarity.\n",
"similarities = query_embeddings @ document_embeddings.T\n",
"print(f\"Similarities:\\n{similarities}\")\n",
"\n",
"# Expected output:\n",
"# tensor([[0.6857, 0.4598],\n",
"#         [0.4238, 0.5723]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "458cec94",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "alexis",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|