kaxap committed
Commit f61780d
0 Parent(s):

Duplicate from kaxap/r-jokes-multilingual-e5-large

Files changed (6)
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. app.py +59 -0
  4. requirements.txt +5 -0
  5. rjokes-embeddings.npy +3 -0
  6. rjokes.csv +0 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Semantic search on r/jokes using Multilingual E5 Large
+ emoji: 🐢
+ colorFrom: indigo
+ colorTo: green
+ sdk: gradio
+ sdk_version: 3.39.0
+ app_file: app.py
+ pinned: false
+ duplicated_from: kaxap/r-jokes-multilingual-e5-large
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,59 @@
+ import gradio as gr
+
+ import pandas as pd
+ import numpy as np
+
+ import torch.nn.functional as F
+
+ from torch import Tensor
+ from transformers import AutoTokenizer, AutoModel
+ from sklearn.metrics.pairwise import cosine_similarity
+
+
+ def average_pool(last_hidden_states: Tensor,
+                  attention_mask: Tensor) -> Tensor:
+     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+
+ df = pd.read_csv('rjokes.csv')
+ data_embeddings = np.load("rjokes-embeddings.npy")
+
+ print("loading the model...")
+ tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large')
+ model = AutoModel.from_pretrained('intfloat/multilingual-e5-large')
+
+ with gr.Blocks() as demo:
+     chatbot = gr.Chatbot()
+     msg = gr.Textbox(label="r/jokes semantic search query", placeholder="for example, \"programming and religion\"")
+     clear = gr.ClearButton([msg, chatbot])
+
+     def respond(message, chat_history):
+         batch_dict = tokenizer(["query: " + message], max_length=512, padding=True, truncation=True, return_tensors='pt')
+
+         outputs = model(**batch_dict)
+         input_embedding = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+
+         # normalize embeddings
+         input_embedding = F.normalize(input_embedding, p=2, dim=1)
+         input_embedding = input_embedding[0].tolist()
+
+         # Compute cosine similarities
+         input_embedding = np.array(input_embedding).reshape(1, -1)
+         cos_similarities = cosine_similarity(data_embeddings, input_embedding).flatten()
+
+         # Get top k similar points' indices
+         k = 5  # replace with your value of k
+         top_k_idx = cos_similarities.argsort()[-k:][::-1]
+
+         # Get corresponding 'text' for top k similar points
+         top_k_text = df['text'].iloc[top_k_idx].tolist()
+
+         bot_message = "\n".join(f"{i+1}. {top_k_text[i]}" for i in range(len(top_k_text)))
+
+         chat_history.append((message, bot_message))
+         return "", chat_history
+
+     msg.submit(respond, [msg, chatbot], [msg, chatbot])
+
+ demo.launch()
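
Note: the commit ships rjokes-embeddings.npy ready-made and does not include the script that produced it. For context, below is a minimal sketch of how such an embeddings file could be precomputed with the same model and mean pooling that app.py applies at query time. The "passage: " prefix follows the intfloat/multilingual-e5-large model card; the batch size, column handling, and file names are assumptions, not part of this commit.

# Hypothetical offline step (not part of this commit): precompute rjokes-embeddings.npy
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # Same mean pooling over non-padding tokens as in app.py.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


df = pd.read_csv("rjokes.csv")
tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-large")
model = AutoModel.from_pretrained("intfloat/multilingual-e5-large")
model.eval()

texts = df["text"].astype(str).tolist()
chunks = []
batch_size = 64  # assumption; tune for available memory
for start in range(0, len(texts), batch_size):
    batch = ["passage: " + t for t in texts[start:start + batch_size]]
    batch_dict = tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**batch_dict)
    emb = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
    # L2-normalize so cosine similarity against the normalized query reduces to a dot product.
    emb = F.normalize(emb, p=2, dim=1)
    chunks.append(emb.cpu().numpy())

np.save("rjokes-embeddings.npy", np.concatenate(chunks, axis=0))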
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ numpy
+ pandas
+ transformers
+ scikit-learn
rjokes-embeddings.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:20e61c6a0ebb6fecf9329b1f77ba6d01ff570e8e846fca16cab486830a7350bf
+ size 354214016
rjokes.csv ADDED
The diff for this file is too large to render. See raw diff
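
Usage note: on Hugging Face Spaces the sdk: gradio and app_file: app.py settings in README.md start the demo automatically. To try it locally (assuming a working Python environment), installing the listed dependencies with pip install -r requirements.txt and then running python app.py should launch the Gradio interface, since demo.launch() is called at module level.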