alisharifi committed on
Commit
1c54a6f
·
1 Parent(s): 99036ae
Files changed (7) hide show
  1. .gitattributes +3 -0
  2. README.md +5 -4
  3. app.py +175 -0
  4. idx_item_mapping.pkl +3 -0
  5. image.index +3 -0
  6. requirements.txt +13 -0
  7. text.index +3 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ image.index filter=lfs diff=lfs merge=lfs -text
37
+ text.index filter=lfs diff=lfs merge=lfs -text
38
+ idx_item_mapping.pkl filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: Tourist Attraction Rag
3
- emoji: 🦀
4
  colorFrom: red
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 5.49.1
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Tourist Attractions Multimodal Rag
3
+ emoji: 👀
4
  colorFrom: red
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.46.0
8
  app_file: app.py
9
  pinned: false
10
+ python_version: 3.10.11
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sys
3
+ import os
4
+ import pickle
5
+
6
+ import faiss
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+ from sentence_transformers import SentenceTransformer
12
+ from transformers import AutoImageProcessor, AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig
13
+ from tqdm import tqdm
14
+ from datasets import load_dataset
15
+ from hazm import Normalizer
16
+
17
+
18
# Hugging Face Hub ids of the two datasets opened at start-up.
DATASET_NAME = 'alisharifi/tourist-attractions-text-image-data'
TEST_DATA_NAME = 'alisharifi/tourist-attractions-test-data'

# Hub ids of the three models used by the pipeline.
_VISION_MODEL_ID = 'facebook/dinov2-base'
_TEXT_ENCODER_ID = "xmanii/maux-gte-persian"
_GENERATOR_ID = "universitytehran/PersianMind-v1.0"

# Both datasets are opened in streaming mode.
dataset = load_dataset(DATASET_NAME, streaming=True)
test_data_name = load_dataset(TEST_DATA_NAME, streaming=True)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image encoder and its preprocessor, put in inference mode straight away.
vision_processor = AutoImageProcessor.from_pretrained(_VISION_MODEL_ID)
vision_model = AutoModel.from_pretrained(_VISION_MODEL_ID).to(device).eval()

# Sentence encoder for text queries, also in inference mode.
language_model = SentenceTransformer(
    _TEXT_ENCODER_ID,
    trust_remote_code=True,
).to(device).eval()

# Answer generator, loaded with 4-bit NF4 double quantization.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    _GENERATOR_ID,
    quantization_config=quantization_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(_GENERATOR_ID)

# hazm normalizer applied to every text query before it is embedded.
normalizer = Normalizer()

# Pre-built FAISS indices for the text and image embeddings.
text_index = faiss.read_index("text.index")
image_index = faiss.read_index("image.index")

# Maps a FAISS row id to the item it stands for.
# NOTE: pickle.load can execute arbitrary code; this is acceptable only
# because the file ships with the repository. Never point it at untrusted data.
with open("idx_item_mapping.pkl", "rb") as mapping_file:
    idx_item_mapping = pickle.load(mapping_file)

print("FAISS indices and index-item mapping loaded.")
59
+
60
+
61
def search_by_text(query_text, k=5):
    """
    Searches the text index for the top k distinct items most similar to the query text.

    The query is normalized with the module-level ``normalizer``, embedded with
    ``language_model``, L2-normalized and looked up in ``text_index``.

    Args:
        query_text: The text query.
        k: The maximum number of distinct items to return.

    Returns:
        A list of at most k distinct entries of ``idx_item_mapping``, in the
        order FAISS ranked them. The entries are presumably text strings
        (``rag_pipeline`` joins them with newlines) -- confirm against the
        code that built the mapping.
    """
    if k <= 0:
        return []

    normalized_query = normalizer.normalize(query_text)
    query_embedding = language_model.encode(normalized_query)

    # FAISS works on contiguous float32 matrices of shape (n_queries, dim);
    # normalize_L2 rescales the rows in place.
    query_matrix = np.ascontiguousarray(query_embedding[np.newaxis, :], dtype=np.float32)
    faiss.normalize_L2(query_matrix)

    # Over-fetch so duplicates can be dropped, but never ask for fewer than k.
    n_candidates = max(100, k)
    _, indices = text_index.search(query_matrix, n_candidates)

    seen = set()
    results = []
    for idx in indices[0]:
        # FAISS pads the result with -1 when the index holds fewer vectors
        # than requested; those slots do not correspond to any item.
        if idx < 0:
            continue
        item = idx_item_mapping[int(idx)]
        if item in seen:
            continue
        seen.add(item)
        results.append(item)
        if len(results) >= k:
            break

    return results
92
+
93
def search_by_image(query_image, k=5):
    """
    Searches the image index for the top k distinct items most similar to the query image.

    The image is preprocessed with ``vision_processor``, run through
    ``vision_model``, and the first model output is averaged over dimension 1
    to obtain a single embedding (presumably mean-pooling the token embeddings
    -- this must match how ``image.index`` was built; confirm). The embedding
    is L2-normalized and looked up in ``image_index``.

    Args:
        query_image: The query image, in a form accepted by ``vision_processor``.
        k: The maximum number of distinct items to return.

    Returns:
        A list of at most k distinct entries of ``idx_item_mapping``, in the
        order FAISS ranked them. The entries are presumably text strings
        (``rag_pipeline`` joins them with newlines) -- confirm against the
        code that built the mapping.
    """
    if k <= 0:
        return []

    # Move the preprocessed image tensors to the same device as the model.
    inputs = vision_processor(images=query_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = vision_model(**inputs)
    image_embedding = outputs[0].mean(dim=1)[0].cpu().numpy()

    # FAISS works on contiguous float32 matrices of shape (n_queries, dim);
    # normalize_L2 rescales the rows in place.
    query_matrix = np.ascontiguousarray(image_embedding[np.newaxis, :], dtype=np.float32)
    faiss.normalize_L2(query_matrix)

    # Over-fetch so duplicates can be dropped, but never ask for fewer than k.
    n_candidates = max(100, k)
    _, indices = image_index.search(query_matrix, n_candidates)

    seen = set()
    results = []
    for idx in indices[0]:
        # FAISS pads the result with -1 when the index holds fewer vectors
        # than requested; those slots do not correspond to any item.
        if idx < 0:
            continue
        item = idx_item_mapping[int(idx)]
        if item in seen:
            continue
        seen.add(item)
        results.append(item)
        if len(results) >= k:
            break

    return results
129
+
130
+
131
def rag_pipeline(question, image=None):
    """
    Runs the RAG pipeline with the given question and optional image.

    Items retrieved for the image (if any) and for the question are merged,
    placed in front of the question as context, and the language model
    generates an answer.

    Args:
        question: The text question. ``None`` is treated as an empty string.
        image: Optional image input.

    Returns:
        The generated answer from the language model, without the prompt and
        with surrounding whitespace stripped.
    """
    question = question or ""

    retrieved_items = []
    if image is not None:
        retrieved_items.extend(search_by_image(image))
    retrieved_items.extend(search_by_text(question))
    # The two retrievers can return the same item; keep the first occurrence
    # only so the context is not padded with repeats.
    retrieved_items = list(dict.fromkeys(retrieved_items))

    template = "{context}\nYou: {prompt}\nPersianMind: "
    context = '\n'.join(retrieved_items)
    prompt = '\n'.join([
        question,
        'به این سوال به فارسی جواب بده.'
    ])

    model_input = template.format(context=context, prompt=prompt)
    # The model is loaded with device_map="auto", so send the inputs to the
    # device the model reports rather than to the global `device`.
    input_tokens = tokenizer(model_input, return_tensors="pt").to(model.device)
    prompt_length = input_tokens["input_ids"].shape[1]

    generate_ids = model.generate(
        **input_tokens,
        max_new_tokens=200,
        do_sample=False,
        repetition_penalty=1.1,
    )

    # generate() returns the prompt tokens followed by the new tokens. Cut the
    # prompt off at token level: slicing the decoded string by
    # len(model_input) breaks whenever decoding does not reproduce the prompt
    # character for character.
    answer_ids = generate_ids[:, prompt_length:]
    answer = tokenizer.batch_decode(
        answer_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    return answer.strip()
161
+
162
+
163
# Gradio front end: a question box and an optional image in, the answer out.
_question_box = gr.Textbox(label="Your Question")
_image_box = gr.Image(type="pil", label="Optional Image")
_answer_box = gr.Textbox(label="Answer")

iface = gr.Interface(
    title="Tourist Attraction RAG Pipeline",
    description="Ask a question about tourist attractions and optionally provide an image.",
    fn=rag_pipeline,
    inputs=[_question_box, _image_box],
    outputs=_answer_box,
)


# Start the web server with debug output enabled.
iface.launch(debug=True)
idx_item_mapping.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f91dd8e55ba36208407b204b38725beb034febc212dbf4eb40c0bbbc6e31e53
3
+ size 1486498
image.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d469ba218d61ae40b215f2074b73c66dcbc1e6380f5a416a661be12ced229470
3
+ size 6703149
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ faiss-cpu
2
+ gradio
3
+ numpy
4
+ torch
5
+ pillow
6
+ sentence_transformers
7
+ transformers
8
+ datasets
9
+ hazm
10
+ tqdm
11
+ bitsandbytes
12
+ accelerate
13
+ sentencepiece
text.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8519dad63618a4cab940e95b55639196220236f2de36a6fec8aeb1bca385660
3
+ size 6703149