not-lain committed
Commit 07ffad3 • 1 Parent(s): 37830e1

🌘w🌖

Files changed (3)
  1. README.md +1 -1
  2. app.py +161 -0
  3. requirements.txt +6 -0
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: RAG
- emoji: 🌍
+ emoji: 🌘w🌖
  colorFrom: yellow
  colorTo: red
  sdk: gradio
app.py ADDED
@@ -0,0 +1,161 @@
+ import gradio as gr
+ from datasets import load_dataset
+ from sentence_transformers import SentenceTransformer
+ from sentence_transformers.quantization import quantize_embeddings
+ import faiss
+ from usearch.index import Index
+ import os
+ import spaces
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ import torch
+ from threading import Thread
+
+ token = os.environ["HF_TOKEN"]
+ model = AutoModelForCausalLM.from_pretrained(
+     "google/gemma-7b-it",
+     # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     torch_dtype=torch.float16,
+     token=token,
+ )
+ tok = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
+ device = torch.device("cuda")
+ model = model.to(device)
+
+ # Load titles and texts
+ title_text_dataset = load_dataset(
+     "mixedbread-ai/wikipedia-data-en-2023-11", split="train", num_proc=4
+ ).select_columns(["title", "text"])
+
+ # Load the int8 and binary indices. Int8 is loaded as a view to save memory,
+ # as we never actually perform search with it.
+ int8_view = Index.restore("wikipedia_int8_usearch_50m.index", view=True)
+ binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(
+     "wikipedia_ubinary_faiss_50m.index"
+ )
+ binary_ivf: faiss.IndexBinaryIVF = faiss.read_index_binary(
+     "wikipedia_ubinary_ivf_faiss_50m.index"
+ )
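+
+ # The binary index serves the fast first-stage search; the int8 embeddings are
+ # only used afterwards to rescore its candidates, so a memory-mapped view with
+ # no search structure is sufficient for them.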
+
+ # Load the SentenceTransformer model for embedding the queries.
+ # Named embedding_model so it does not shadow the Gemma model above.
+ embedding_model = SentenceTransformer(
+     "mixedbread-ai/mxbai-embed-large-v1",
+     prompts={
+         "retrieval": "Represent this sentence for searching relevant passages: ",
+     },
+     default_prompt_name="retrieval",
+ )
+
+
+ def search(
+     query, top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
+ ):
+     # 1. Embed the query as float32
+     query_embedding = embedding_model.encode(query)
+
+     # 2. Quantize the query to ubinary
+     query_embedding_ubinary = quantize_embeddings(
+         query_embedding.reshape(1, -1), "ubinary"
+     )
+
+     # 3. Search the binary index (either exact or approximate)
+     index = binary_ivf if use_approx else binary_index
+     _scores, binary_ids = index.search(
+         query_embedding_ubinary, top_k * rescore_multiplier
+     )
+     binary_ids = binary_ids[0]
+
+     # 4. Load the corresponding int8 embeddings
+     int8_embeddings = int8_view[binary_ids].astype(int)
+
+     # 5. Rescore the top_k * rescore_multiplier using the float32 query embedding
+     #    and the int8 document embeddings
+     scores = query_embedding @ int8_embeddings.T
+
+     # 6. Sort the scores and return the top_k
+     indices = scores.argsort()[::-1][:top_k]
+     top_k_indices = binary_ids[indices]
+     top_k_scores = scores[indices]
+     top_k_titles, top_k_texts = zip(
+         *[
+             (title_text_dataset[idx]["title"], title_text_dataset[idx]["text"])
+             for idx in top_k_indices.tolist()
+         ]
+     )
+     df = {
+         "Score": [round(value, 2) for value in top_k_scores],
+         "Title": top_k_titles,
+         "Text": top_k_texts,
+     }
+
+     return df
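+
+ # Illustrative call: search("What is FAISS?", top_k=3) returns a dict of
+ # parallel lists, {"Score": [...], "Title": [...], "Text": [...]}, which
+ # prepare_prompt below consumes.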
+
+ def prepare_prompt(query, df):
+     prompt = f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
+     # df is a dict of parallel lists; iterating the dict itself would only
+     # yield the key names, so walk the paired Title/Text columns instead.
+     for title, text in zip(df["Title"], df["Text"]):
+         prompt += f"Title: {title}, Text: {text}\n"
+     return prompt
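+
+ # The assembled prompt has the shape (illustrative):
+ #   Query: <user question>
+ #   Continue to answer the query by using the Search Results:
+ #   Title: <title>, Text: <passage>
+ #   ...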
+
+
+ @spaces.GPU
+ def talk(message, history):
+     df = search(message)
+     message = prepare_prompt(message, df)
+     resources = "\nRESOURCES:\n"
+     for title in df["Title"][:3]:
+         resources += f"[{title}](https://huggingface.co/spaces/not-lain/RAG), "
+     chat = []
+     for item in history:
+         chat.append({"role": "user", "content": item[0]})
+         if item[1] is not None:
+             # Strip the resources footer from past answers before re-feeding them
+             cleaned_past = item[1].split("\nRESOURCES:\n")[0]
+             chat.append({"role": "assistant", "content": cleaned_past})
+     chat.append({"role": "user", "content": message})
+     messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+     # Tokenize the messages string
+     model_inputs = tok([messages], return_tensors="pt").to(device)
+     streamer = TextIteratorStreamer(
+         tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+     )
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         top_p=0.95,
+         top_k=1000,
+         temperature=0.75,
+         num_beams=1,
+     )
+     # Generate in a background thread so the streamer can yield tokens as they arrive
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     # Initialize an empty string to store the generated text
+     partial_text = ""
+     for new_text in streamer:
+         partial_text += new_text
+         yield partial_text
+     partial_text += resources
+     yield partial_text
+
+
+ TITLE = "RAG"
+
+ DESCRIPTION = """
+ ## Resources used to build this project
+ * https://huggingface.co/learn/cookbook/rag_with_hugging_face_gemma_mongodb
+ * https://huggingface.co/spaces/sentence-transformers/quantized-retrieval
+ ## Retrieval parameters
+ ```python
+ top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
+ ```
+ ## Models
+ The models used in this Space are:
+ * google/gemma-7b-it (chat generation)
+ * mixedbread-ai/mxbai-embed-large-v1 (query embeddings, searched against the
+   mixedbread-ai/wikipedia-data-en-2023-11 dataset)
+ """
+
+ demo = gr.ChatInterface(
+     fn=talk,
+     chatbot=gr.Chatbot(
+         show_label=True,
+         show_share_button=True,
+         show_copy_button=True,
+         likeable=True,
+         layout="bubble",
+         bubble_full_width=False,
+     ),
+     theme="soft",
+     examples=[["Write me a poem about Machine Learning."]],
+     title=TITLE,
+     description=DESCRIPTION,
+ )
+ demo.launch()
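For reference, here is a minimal, self-contained sketch of the same embed → quantize → binary search → int8 rescore flow that `search` implements above. The three-sentence corpus and variable names are illustrative stand-ins for the prebuilt 50M-passage Wikipedia indices the Space loads from disk, and it assumes a sentence-transformers build that ships `quantize_embeddings`:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
import faiss

# Same embedding model as the Space; the tiny corpus is purely illustrative.
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

corpus = [
    "The moon is Earth's only natural satellite.",
    "Gradio lets you build machine learning demos in Python.",
    "FAISS is a library for efficient similarity search.",
]

# Float32 document embeddings, plus the two quantized variants used above:
# ubinary for the fast first-stage index, int8 for the rescoring pass.
doc_emb = model.encode(corpus)
doc_ubinary = quantize_embeddings(doc_emb, "ubinary")
doc_int8 = quantize_embeddings(doc_emb, "int8")

# A flat binary index over the packed bits (Hamming distance).
binary_index = faiss.IndexBinaryFlat(doc_emb.shape[1])
binary_index.add(doc_ubinary)

# Stage 1: quantize the query and fetch candidates from the binary index.
# The app configures this retrieval prompt on the model; it is inlined here.
query = "Represent this sentence for searching relevant passages: What does FAISS do?"
query_emb = model.encode(query)
query_ubinary = quantize_embeddings(query_emb.reshape(1, -1), "ubinary")
_scores, candidate_ids = binary_index.search(query_ubinary, 2)
candidate_ids = candidate_ids[0]

# Stage 2: rescore the candidates with the float32 query embedding against
# the int8 document embeddings, then sort descending by score.
scores = query_emb @ doc_int8[candidate_ids].astype(int).T
for idx in candidate_ids[scores.argsort()[::-1]]:
    print(corpus[idx])
```

The binary stage keeps candidate retrieval small and fast, while the int8 rescoring pass recovers most of the float32 ranking quality; this trade-off is what the quantized-retrieval Space linked in the description is built around.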
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ spaces
+ torch==2.2.0
+ git+https://github.com/huggingface/transformers/
+ git+https://github.com/tomaarsen/sentence-transformers@feat/quantization
+ usearch
+ faiss-cpu