code5ecure committed on
Commit
73eb94f
·
verified ·
1 Parent(s): c52be7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -22
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import gradio as gr
5
  from sentence_transformers import SentenceTransformer
6
  import faiss
 
7
 
8
  # Disable torch.compile to avoid meta device issues
9
  torch._dynamo.config.suppress_errors = True
@@ -12,15 +13,21 @@ torch.set_default_dtype(torch.float32)
12
  # Set device explicitly
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
- # Load LLaMA 2 Persian model and tokenizer
16
- model_name = "sinarashidi/llama-2-7b-chat-persian"
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)
 
 
 
 
 
19
 
20
  # Differential Privacy parameters
21
  epsilon = 1.0 # Privacy budget
22
  delta = 1e-5 # Privacy parameter
23
  sensitivity = 1.0 # Sensitivity of the query
 
24
 
25
  # Simple memory for conversation history
26
  conversation_history = []
@@ -48,14 +55,19 @@ def load_training_data():
48
  # Build RAG index
49
  def build_rag_index(texts):
50
  global embedder, index
51
- embedder = SentenceTransformer('xmanii/maux-gte-persian')
52
- embeddings = embedder.encode(texts, convert_to_tensor=True).cpu().numpy()
53
- dimension = embeddings.shape[1]
54
- index = faiss.IndexFlatL2(dimension)
55
- index.add(embeddings)
56
- return embedder, index
57
-
58
- # Fine-tune model with differential privacy (skipped to use only pretrained LLaMA 2)
 
 
 
 
 
59
  def train_model():
60
  global texts, embedder, index
61
  texts = load_training_data()
@@ -65,8 +77,7 @@ def train_model():
65
 
66
  # Build RAG index
67
  build_rag_index(texts)
68
-
69
- print("Skipping fine-tuning to use only pretrained LLaMA 2 model.")
70
 
71
  def add_noise(tensor, sensitivity, epsilon, delta):
72
  """Add Laplace noise for differential privacy."""
@@ -87,13 +98,15 @@ def chat(message, history):
87
  model.eval()
88
 
89
  # RAG retrieval
 
90
  if embedder and index:
91
- query_emb = embedder.encode(message, convert_to_tensor=True).cpu().numpy()
92
- D, I = index.search(query_emb, k=3)
93
- retrieved = [texts[i] for i in I[0] if i >= 0 and i < len(texts)]
94
- context = "\n".join(retrieved)
95
- else:
96
- context = ""
 
97
 
98
  # Prepare prompt with context
99
  prompt = f"Context: {context}\nUser: {message}\nBot:"
@@ -113,19 +126,26 @@ def chat(message, history):
113
  )
114
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
115
 
 
 
 
 
 
 
 
116
  # Update conversation history
117
  update_model(message, response)
118
 
119
  return response
120
 
121
- # Train the model on startup (now only loads data and builds RAG index)
122
  train_model()
123
 
124
  # Gradio interface
125
  iface = gr.ChatInterface(
126
  fn=chat,
127
- title="LLaMA 2 Persian Chatbot with RAG",
128
- description="Chat with pretrained LLaMA 2 Persian model using training_data.txt as RAG knowledge base."
129
  )
130
 
131
  if __name__ == "__main__":
 
import gradio as gr
from sentence_transformers import SentenceTransformer
import faiss
# BUG FIX: bitsandbytes exposes no `quantize_model` symbol, so the original
# `from bitsandbytes import quantize_model` raised ImportError at startup
# (and the name was never used). Importing the package is enough —
# transformers picks it up automatically for `load_in_4bit` quantization.
import bitsandbytes  # noqa: F401

# Disable torch.compile to avoid meta device issues
torch._dynamo.config.suppress_errors = True
 
# Set device explicitly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load LLaMA 3.2 1B model with 4-bit quantization
model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    load_in_4bit=True,   # 4-bit quantization via bitsandbytes
    device_map="auto",   # accelerate places the weights on the right device
)
# BUG FIX: the original chained `.to(device)` here. A model loaded with
# device_map="auto" is dispatched by accelerate, and 4-bit quantized models
# cannot be moved with `.to()` — transformers raises a ValueError, so the
# app crashed on startup. device_map already handles placement.
25
 
# --- Differential-privacy settings ---------------------------------------
epsilon = 1.0       # privacy budget
delta = 1e-5        # privacy failure parameter
sensitivity = 1.0   # sensitivity of the query (scales the noise)
apply_dp = False    # set to True to enable DP noise during inference

# Simple in-memory store for the conversation history.
conversation_history = []
 
55
  # Build RAG index
56
def build_rag_index(texts):
    """Embed *texts* and build a FAISS L2 index over them.

    Sets the module-level ``embedder`` and ``index`` globals and returns
    ``(embedder, index)``; on any failure logs the error and returns
    ``(None, None)`` so the app can run without RAG.
    """
    global embedder, index
    try:
        # Embed on CPU with a small batch size to keep memory usage down.
        embedder = SentenceTransformer('xmanii/maux-gte-persian', device='cpu')
        vectors = embedder.encode(texts, convert_to_tensor=True, batch_size=16).cpu().numpy()
        index = faiss.IndexFlatL2(vectors.shape[1])
        index.add(vectors)
        print("RAG index built successfully")
        return embedder, index
    except Exception as e:
        print(f"Error building RAG index: {e}")
        return None, None
69
+
70
+ # Initialize model and RAG (no fine-tuning)
71
  def train_model():
72
  global texts, embedder, index
73
  texts = load_training_data()
 
77
 
78
  # Build RAG index
79
  build_rag_index(texts)
80
+ print("Using pretrained LLaMA 3.2 1B model without fine-tuning.")
 
81
 
82
  def add_noise(tensor, sensitivity, epsilon, delta):
83
  """Add Laplace noise for differential privacy."""
 
98
model.eval()

# RAG retrieval: pull the top-3 nearest training snippets as context.
context = ""
if embedder and index:
    try:
        query_emb = embedder.encode(message, convert_to_tensor=True).cpu().numpy()
        # BUG FIX: faiss `search` expects a 2-D (n_queries, dim) array, but
        # encoding a single string yields a 1-D vector — every search raised
        # and the except below silently disabled RAG. Reshape to one row.
        D, I = index.search(query_emb.reshape(1, -1), k=3)
        # Guard against faiss's -1 "no result" ids and stale indices.
        retrieved = [texts[i] for i in I[0] if 0 <= i < len(texts)]
        context = "\n".join(retrieved)
    except Exception as e:
        print(f"Error in RAG retrieval: {e}")

# Prepare prompt with context
prompt = f"Context: {context}\nUser: {message}\nBot:"
 
126
  )
127
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
128
 
129
+ # Apply differential privacy noise to logits (optional)
130
+ if apply_dp:
131
+ logits = model(**inputs).logits
132
+ noisy_logits = add_noise(logits, sensitivity, epsilon, delta)
133
+ response_ids = torch.argmax(noisy_logits, dim=-1)
134
+ response = tokenizer.decode(response_ids[0], skip_special_tokens=True)
135
+
136
  # Update conversation history
137
  update_model(message, response)
138
 
139
  return response
140
 
141
# Load the training data and build the RAG index at startup (no fine-tuning).
train_model()

# Gradio chat UI; `chat` is called with (message, history) by ChatInterface.
iface = gr.ChatInterface(
    fn=chat,
    title="LLaMA 3.2 1B Persian Chatbot with RAG",
    description="Chat with pretrained LLaMA 3.2 1B model using training_data.txt as RAG knowledge base.",
)
 
151
  if __name__ == "__main__":