zihanliu commited on
Commit
3916ade
1 Parent(s): b80bf5a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -11
README.md CHANGED
@@ -107,30 +107,27 @@ contexts = [
107
 
108
  ## convert query into a format as follows:
109
  ## user: {user}\nagent: {agent}\nuser: {user}
110
- formatted_query = ""
111
- for turn in query:
112
- formatted_query += turn['role'] + ": " + turn['content'] + "\n"
113
- formatted_query = formatted_query.strip()
114
 
115
  ## get query and context embeddings
116
  query_input = tokenizer(formatted_query, return_tensors='pt')
117
  ctx_input = tokenizer(contexts, padding=True, return_tensors='pt')
118
- query_emb = query_encoder(**query_input).last_hidden_state[:, 0, :]
119
- ctx_emb = context_encoder(**ctx_input).last_hidden_state[:, 0, :]
120
 
121
- # Compute similarity scores using dot product
122
- score1 = query_emb @ ctx_emb[0]
123
- score2 = query_emb @ ctx_emb[1]
 
 
124
  ```
125
 
126
  ## License
127
  Dragon-multiturn is built on top of [Dragon](https://arxiv.org/abs/2302.07452). We refer users to the original license of the Dragon model.
128
 
129
-
130
  ## Correspondence to
131
  Zihan Liu (zihanl@nvidia.com), Wei Ping (wping@nvidia.com)
132
 
133
-
134
  ## Citation
135
  <pre>
136
  @article{liu2024chatqa,
 
107
 
108
  ## convert query into a format as follows:
109
  ## user: {user}\nagent: {agent}\nuser: {user}
110
+ formatted_query = '\n'.join([turn['role'] + ": " + turn['content'] for turn in messages]).strip()
 
 
 
111
 
112
  ## get query and context embeddings
113
  query_input = tokenizer(formatted_query, return_tensors='pt')
114
  ctx_input = tokenizer(contexts, padding=True, return_tensors='pt')
115
+ query_emb = query_encoder(**query_input).last_hidden_state[:, 0, :] # (1, emb_dim)
116
+ ctx_emb = context_encoder(**ctx_input).last_hidden_state[:, 0, :] # (num_ctx, emb_dim)
117
 
118
+ ## Compute similarity scores using dot product
119
+ similarities = query_emb.matmul(ctx_emb.transpose(0, 1)) # (1, num_ctx)
120
+
121
+ ## rank the similarity (from highest to lowest)
122
+ ranked_results = torch.argsort(similarities, dim=-1, descending=True) # (1, num_ctx)
123
  ```
124
 
125
  ## License
126
  Dragon-multiturn is built on top of [Dragon](https://arxiv.org/abs/2302.07452). We refer users to the original license of the Dragon model.
127
 
 
128
  ## Correspondence to
129
  Zihan Liu (zihanl@nvidia.com), Wei Ping (wping@nvidia.com)
130
 
 
131
  ## Citation
132
  <pre>
133
  @article{liu2024chatqa,