{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2a12a2b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from safetensors import safe_open\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from transformers import AutoModel, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "148ce181",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First clone the model locally\n",
    "!git clone https://huggingface.co/MongoDB/mdbr-leaf-ir"
   ]
  },
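  {
   "cell_type": "markdown",
   "id": "b2d4e6f8",
   "metadata": {},
   "source": [
    "Note: Hugging Face model repos store large weight files via Git LFS. If the clone leaves `model.safetensors` as a small pointer file, run `git lfs install` and re-clone, or fetch the files with `huggingface_hub.snapshot_download` instead."
   ]
  },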
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ba9ec6c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Then load it\n",
    "MODEL = \"mdbr-leaf-ir\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
    "model = AutoModel.from_pretrained(MODEL, add_pooling_layer=False)"
   ]
  },
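  {
   "cell_type": "markdown",
   "id": "c3e5a7b9",
   "metadata": {},
   "source": [
    "A quick optional sanity check, as a small sketch: the dense projection loaded in the next cell expects 384-dim inputs, so the encoder's hidden size should be 384. `hidden_size` is a standard `transformers` config attribute for BERT-family models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4f6b8ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Should print 384, matching the in_features of the dense layer below\n",
    "print(model.config.hidden_size)"
   ]
  },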
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ebaf1a76",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Similarities:\n",
      "tensor([[0.6857, 0.4598],\n",
      "        [0.4238, 0.5723]])\n"
     ]
    }
   ],
   "source": [
    "tensors = {}\n",
    "with safe_open(MODEL + \"/2_Dense/model.safetensors\", framework=\"pt\") as f:\n",
    "    for k in f.keys():\n",
    "        tensors[k] = f.get_tensor(k)\n",
    "\n",
    "W_out = torch.nn.Linear(in_features=384, out_features=768, bias=True)\n",
    "W_out.load_state_dict({\n",
    "    \"weight\": tensors[\"linear.weight\"], \n",
    "    \"bias\": tensors[\"linear.bias\"]\n",
    "})\n",
    "\n",
    "_ = model.eval()\n",
    "_ = W_out.eval()\n",
    "\n",
    "# Example queries and documents  \n",
    "queries = [\n",
    "    \"What is machine learning?\",  \n",
    "    \"How does neural network training work?\"  \n",
    "]  \n",
    "  \n",
    "documents = [  \n",
    "    \"Machine learning is a subset of artificial intelligence that focuses on algorithms that can learn from data.\",  \n",
    "    \"Neural networks are trained through backpropagation, adjusting weights to minimize prediction errors.\"  \n",
    "]\n",
    "\n",
    "# Tokenize\n",
    "QUERY_PREFIX = 'Represent this sentence for searching relevant passages: '\n",
    "queries_with_prefix = [QUERY_PREFIX + query for query in queries]\n",
    "\n",
    "query_tokens = tokenizer(queries_with_prefix, padding=True, truncation=True, return_tensors='pt', max_length=512)\n",
    "document_tokens =  tokenizer(documents, padding=True, truncation=True, return_tensors='pt', max_length=512)\n",
    "\n",
    "# Perform Inference\n",
    "with torch.inference_mode():\n",
    "    y_queries = model(**query_tokens).last_hidden_state\n",
    "    y_docs = model(**document_tokens).last_hidden_state\n",
    "\n",
    "    # perform pooling\n",
    "    y_queries = y_queries * query_tokens.attention_mask.unsqueeze(-1)\n",
    "    y_queries_pooled = y_queries.sum(dim=1) / query_tokens.attention_mask.sum(dim=1, keepdim=True)\n",
    "\n",
    "    y_docs = y_docs * document_tokens.attention_mask.unsqueeze(-1)\n",
    "    y_docs_pooled = y_docs.sum(dim=1) / document_tokens.attention_mask.sum(dim=1, keepdim=True)\n",
    "\n",
    "    # map to desired output dimension\n",
    "    y_queries_out = W_out(y_queries_pooled)\n",
    "    y_docs_out = W_out(y_docs_pooled)\n",
    "\n",
    "    # normalize and return\n",
    "    query_embeddings = F.normalize(y_queries_out, dim=-1)\n",
    "    document_embeddings = F.normalize(y_docs_out, dim=-1)\n",
    "\n",
    "similarities = query_embeddings @ document_embeddings.T\n",
    "print(f\"Similarities:\\n{similarities}\")\n",
    "\n",
    "# Similarities:\n",
    "#  tensor([[0.6857, 0.4598],\n",
    "#          [0.4238, 0.5723]])"
   ]
  },
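  {
   "cell_type": "markdown",
   "id": "e5a7c9db",
   "metadata": {},
   "source": [
    "As a follow-up sketch, the similarity matrix can be used directly to rank the documents for each query. Nothing here is model-specific; it only reuses `similarities`, `queries`, and `documents` from the cell above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6b8daec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank documents per query by descending cosine similarity\n",
    "ranking = torch.argsort(similarities, dim=1, descending=True)\n",
    "for qi, query in enumerate(queries):\n",
    "    print(f\"Query: {query}\")\n",
    "    for rank, di in enumerate(ranking[qi].tolist()):\n",
    "        print(f\"  {rank + 1}. ({similarities[qi, di]:.4f}) {documents[di]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7c9ebfd",
   "metadata": {},
   "source": [
    "Hedged alternative: the `2_Dense/` folder suggests the cloned repo follows the sentence-transformers layout. If that is the case (and `sentence-transformers` is installed), the manual pipeline above collapses to a few lines; this is a sketch under that assumption, with the manual version remaining the reference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8dafc0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: load the same local clone via sentence-transformers (assumes the\n",
    "# repo ships a full sentence-transformers config alongside 2_Dense/)\n",
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "st_model = SentenceTransformer(MODEL)\n",
    "q_emb = st_model.encode([QUERY_PREFIX + q for q in queries], normalize_embeddings=True)\n",
    "d_emb = st_model.encode(documents, normalize_embeddings=True)\n",
    "print(q_emb @ d_emb.T)  # should approximately match the similarities above"
   ]
  }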
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "alexis",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}