gvozdev committed on
Commit
56ebedb
1 Parent(s): 7dc82e6

Upload first working version

Browse files
Files changed (3) hide show
  1. README.md +1 -7
  2. main.py +107 -0
  3. requirements.txt +81 -0
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
  title: Subspace
3
- emoji: 🐠
4
- colorFrom: red
5
- colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 3.35.2
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Subspace
3
+ app_file: main.py
 
 
4
  sdk: gradio
5
  sdk_version: 3.35.2
 
 
6
  ---
 
 
main.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from datasets import load_dataset
3
+ from transformers import AutoTokenizer, AutoModel
4
+ import torch
5
+ import pandas as pd
6
+ import os
7
+
8
# NOTE(review): an empty CURL_CA_BUNDLE is presumably a workaround that makes
# the HTTP client skip TLS certificate verification when downloading from the
# Hub. It weakens transport security -- confirm it is still required before
# keeping it in a deployed app.
os.environ['CURL_CA_BUNDLE'] = ''

# Load the question/answer dataset (only the "train" split is used).
issues_dataset = load_dataset("gvozdev/subspace-info-v2", split="train")

# Load the tokenizer and the encoder model from the same checkpoint.
# `trust_remote_code` is intentionally not enabled: it permits executing code
# shipped in the model repository and is only needed for custom architectures.
# NOTE(review): assumes this checkpoint uses a built-in architecture -- confirm.
model_ckpt = "sentence-transformers/all-MiniLM-L12-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

# NOTE: concatenating "subject" and "details" into a single "text" column was
# tried and dropped -- embedding only "subject" returned better results, so the
# dataset is used as loaded (no extra identity `map()` pass is needed).

# To speed up embedding, switch to GPU (change device to "cuda") - useful for
# larger models.
device = torch.device("cpu")
model.to(device)
33
+
34
+ # CLS pooling on model’s outputs: collect the last hidden state for the special [CLS] token
35
def cls_pooling(model_output):
    """Return the hidden state at sequence position 0 for every batch item.

    ``model_output`` must expose a ``last_hidden_state`` attribute that
    supports ``[:, 0]`` indexing; position 0 is where the special [CLS]
    token sits.
    """
    hidden_states = model_output.last_hidden_state
    first_token = hidden_states[:, 0]
    return first_token
37
+
38
+
39
+ # Tokenize a list of documents, place the tensors on the CPU/GPU, feed them to the model,
40
+ # and apply CLS pooling to the outputs
41
def get_embeddings(text_list):
    """Embed one or more texts with the module-level tokenizer and model.

    Args:
        text_list: A string or a list of strings to embed.

    Returns:
        The CLS-pooled output of the model as a tensor located on ``device``,
        one row per input text.
    """
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    # Move every input tensor to the device the model lives on.
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    # Inference only: disable autograd so no graph is built or kept in memory.
    with torch.no_grad():
        model_output = model(**encoded_input)
    return cls_pooling(model_output)
48
+
49
+ # Test if the function works
50
+ # embedding = get_embeddings(issues_dataset["details"][0])
51
+ # print(embedding.shape)
52
+
53
+
54
+ # Use Dataset.map() to apply get_embeddings() function to each row in the dataset and create a new "embeddings" column
55
+ # Convert the embeddings to NumPy arrays as Datasets requires this format when we try to index them with FAISS
56
def _embed_subject(row):
    """Build the "embeddings" value for one dataset row from its "subject"."""
    # A single string is embedded, so the leading (batch) axis is dropped
    # with [0] after converting the tensor to a NumPy array.
    vector = get_embeddings(row["subject"]).detach().cpu().numpy()[0]
    return {"embeddings": vector}


# Add an "embeddings" column computed row by row.
embeddings_dataset = issues_dataset.map(_embed_subject)

# Build a FAISS index over the new column for nearest-neighbour lookups.
embeddings_dataset.add_faiss_index(column="embeddings")
62
+
63
+
64
+ #
65
def answer_question(question):
    """Return the "details" text of the dataset entry closest to ``question``.

    The question is embedded with ``get_embeddings`` and matched against the
    FAISS-indexed "embeddings" column; only the single best match is used.
    """
    query_vector = get_embeddings([question]).cpu().detach().numpy()

    # k=1: only the nearest neighbour is requested, so the scores are unused.
    _, nearest = embeddings_dataset.get_nearest_examples(
        "embeddings", query_vector, k=1
    )
    nearest_df = pd.DataFrame.from_dict(nearest)

    # NOTE: with k > 1, attach the scores as a "scores" column and sort the
    # frame by it before picking a row.
    return nearest_df["details"].values[0]
81
+
82
+
83
+ # Gradio interface
84
# Texts shown at the top of the Gradio page.
title = "Subspace Docs bot"
description = (
    '<p style="text-align: center;">This is a bot trained on Subspace Network documentation '
    'to answer the most common questions about the project</p>'
)
87
+
88
+
89
def chat(message, history):
    """Answer ``message`` and append the exchange to the chat history.

    Args:
        message: The user's question.
        history: List of ``(message, response)`` tuples from earlier turns,
            or a falsy value when the conversation is just starting.

    Returns:
        The updated history twice: once for the chatbot output and once for
        the state output of the interface.
    """
    if not history:
        history = []
    history.append((message, answer_question(message)))
    return history, history
94
+
95
+
96
# Build the chat UI: a text box plus session state in, a chatbot widget plus
# the updated state out.
example_questions = [
    "What is Subspace Network?",
    "Do you have a token?",
    "System requirements",
]
iface = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    examples=example_questions,
    title=title,
    description=description,
    theme="Monochrome",
    allow_flagging="never",
)

# NOTE(review): share=True requests a public share link; confirm it is wanted
# in the environment this app is deployed to.
iface.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.1.0
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ altair==5.0.1
5
+ anyio==3.7.0
6
+ async-timeout==4.0.2
7
+ attrs==23.1.0
8
+ certifi==2023.5.7
9
+ charset-normalizer==3.1.0
10
+ click==8.1.3
11
+ contourpy==1.1.0
12
+ cycler==0.11.0
13
+ datasets==2.13.1
14
+ dill==0.3.6
15
+ einops==0.6.1
16
+ evaluate==0.4.0
17
+ exceptiongroup==1.1.1
18
+ faiss-cpu==1.7.4
19
+ fastapi==0.98.0
20
+ ffmpy==0.3.0
21
+ filelock==3.12.2
22
+ fonttools==4.40.0
23
+ frozenlist==1.3.3
24
+ fsspec==2023.6.0
25
+ gradio==3.35.2
26
+ gradio_client==0.2.7
27
+ h11==0.14.0
28
+ httpcore==0.17.2
29
+ httpx==0.24.1
30
+ huggingface-hub==0.15.1
31
+ idna==3.4
32
+ Jinja2==3.1.2
33
+ jsonschema==4.17.3
34
+ kiwisolver==1.4.4
35
+ linkify-it-py==2.0.2
36
+ markdown-it-py==2.2.0
37
+ MarkupSafe==2.1.3
38
+ matplotlib==3.7.1
39
+ mdit-py-plugins==0.3.3
40
+ mdurl==0.1.2
41
+ mpmath==1.3.0
42
+ multidict==6.0.4
43
+ multiprocess==0.70.14
44
+ networkx==3.1
45
+ numpy==1.25.0
46
+ orjson==3.9.1
47
+ packaging==23.1
48
+ pandas==2.0.3
49
+ Pillow==9.5.0
50
+ pyarrow==12.0.1
51
+ pydantic==1.10.9
52
+ pydub==0.25.1
53
+ Pygments==2.15.1
54
+ pyparsing==3.1.0
55
+ pyrsistent==0.19.3
56
+ python-dateutil==2.8.2
57
+ python-multipart==0.0.6
58
+ pytz==2023.3
59
+ PyYAML==6.0
60
+ regex==2023.6.3
61
+ requests==2.27.1
62
+ responses==0.18.0
63
+ safetensors==0.3.1
64
+ semantic-version==2.10.0
65
+ six==1.16.0
66
+ sniffio==1.3.0
67
+ starlette==0.27.0
68
+ sympy==1.12
69
+ tokenizers==0.13.3
70
+ toolz==0.12.0
71
+ torch==2.0.1
72
+ tqdm==4.65.0
73
+ transformers==4.30.2
74
+ typing_extensions==4.7.0
75
+ tzdata==2023.3
76
+ uc-micro-py==1.0.2
77
+ urllib3==2.0.3
78
+ uvicorn==0.22.0
79
+ websockets==11.0.3
80
+ xxhash==3.2.0
81
+ yarl==1.9.2