sagar007 committed · Commit 8a9a6c3 · verified · 1 Parent(s): f4f3cd0

Update app.py

Files changed (1): app.py (+80, -60)

app.py CHANGED
@@ -1,32 +1,49 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 import spaces
 from duckduckgo_search import DDGS
 import time
 import torch
 from datetime import datetime
+import gc  # For manual garbage collection
 
-# Initialize model and tokenizer
+# Initialize model and tokenizer with optimizations
 model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+# Load config first to set optimal parameters
+config = AutoConfig.from_pretrained(model_name)
+config.use_cache = True  # Enable KV-caching for faster inference
+
+# Initialize tokenizer with optimizations
+tokenizer = AutoTokenizer.from_pretrained(
+    model_name,
+    model_max_length=256,  # Reduced for faster processing
+    padding_side="left",
+    truncation_side="left",
+)
 tokenizer.pad_token = tokenizer.eos_token
 
-# Modified model loading for CPU
+# Load model with optimizations
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
-    device_map="cpu",  # Changed to CPU
+    config=config,
+    device_map="cpu",
     low_cpu_mem_usage=True,
-    torch_dtype=torch.float32  # Changed to float32 for CPU
+    torch_dtype=torch.float32,
 )
 
-def get_web_results(query, max_results=5):  # Increased to 5 for better context
+# Enable model optimizations
+model.eval()  # Set to evaluation mode
+torch.set_num_threads(4)  # Limit CPU threads for better performance
+
+def get_web_results(query, max_results=3):  # Reduced max results
     """Get web search results using DuckDuckGo"""
     try:
         with DDGS() as ddgs:
             results = list(ddgs.text(query, max_results=max_results))
             return [{
                 "title": result.get("title", ""),
-                "snippet": result["body"],
+                "snippet": result["body"][:200],  # Limit snippet length
                 "url": result["href"],
                 "date": result.get("published", "")
             } for result in results]
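
A note on the new tokenizer settings: `truncation_side="left"` trims over-long prompts from the front, so the trailing `Answer:` cue in this app's prompt format survives truncation. A minimal sketch of that behavior, assuming the tokenizer loaded above (the filler text is invented for the demo):

```python
# Minimal sketch: left-side truncation keeps the *end* of the prompt.
long_prompt = "filler " * 1000 + "\n\nAnswer:"
enc = tokenizer(long_prompt, truncation=True, max_length=256, return_tensors="pt")
print(enc.input_ids.shape[-1])                    # 256
print(tokenizer.decode(enc.input_ids[0])[-20:])   # should still end with "Answer:"
```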
@@ -34,19 +51,10 @@ def get_web_results(query, max_results=5):
         return []
 
 def format_prompt(query, context):
-    """Format the prompt with web context"""
-    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-    context_lines = '\n'.join([f'- [{res["title"]}]: {res["snippet"]}' for res in context])
-    return f"""You are an intelligent search assistant. Answer the user's query using the provided web context.
-Current Time: {current_time}
-
-Query: {query}
-
-Web Context:
-{context_lines}
-
-Provide a detailed answer in markdown format. Include relevant information from sources and cite them using [1], [2], etc.
-Answer:"""
+    """Format the prompt with web context - optimized version"""
+    context_lines = '\n'.join([f'[{i+1}] {res["snippet"]}'
+                               for i, res in enumerate(context)])
+    return f"""Answer this query using the context: {query}\n\nContext:\n{context_lines}\n\nAnswer:"""
 
 def format_sources(web_results):
     """Format sources with more details"""
@@ -71,69 +79,81 @@ def format_sources(web_results):
     return sources_html
 
 def generate_answer(prompt):
-    """Generate answer using the DeepSeek model"""
-    inputs = tokenizer(
-        prompt,
-        return_tensors="pt",
-        padding=True,
-        truncation=True,
-        max_length=256,  # Reduced max length for CPU
-        return_attention_mask=True
-    )  # Removed .to(model.device) since we're using CPU
-
-    outputs = model.generate(
-        inputs.input_ids,
-        attention_mask=inputs.attention_mask,
-        max_new_tokens=128,  # Reduced for faster generation on CPU
-        temperature=0.7,
-        top_p=0.95,
-        pad_token_id=tokenizer.eos_token_id,
-        do_sample=True,
-        early_stopping=True,
-        num_beams=1  # Reduced beam search for faster generation
-    )
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+    """Generate answer using the DeepSeek model - optimized version"""
+    try:
+        # Clear CUDA cache and garbage collect
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=256,
+            return_attention_mask=True
+        )
+
+        with torch.no_grad():  # Disable gradient calculation
+            outputs = model.generate(
+                inputs.input_ids,
+                attention_mask=inputs.attention_mask,
+                max_new_tokens=100,  # Further reduced for speed
+                temperature=0.7,
+                top_p=0.95,
+                pad_token_id=tokenizer.eos_token_id,
+                do_sample=True,
+                num_beams=1,
+                early_stopping=True,
+                no_repeat_ngram_size=3,
+                length_penalty=1.0
+            )
+
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return response.split('Answer:')[-1].strip()
+
+    except Exception as e:
+        return f"Error generating response: {str(e)}"
 
 def process_query(query, history):
-    """Process user query with streaming effect"""
+    """Process user query with optimized streaming effect"""
     try:
         if history is None:
             history = []
-
+
         # Get web results first
         web_results = get_web_results(query)
        sources_html = format_sources(web_results)
 
-        current_history = history + [[query, "*Searching...*"]]
+        # Show searching status
         yield {
-            answer_output: gr.Markdown("*Searching the web...*"),
+            answer_output: gr.Markdown("*Searching and generating response...*"),
             sources_output: gr.HTML(sources_html),
-            search_btn: gr.Button("Searching...", interactive=False),
-            chat_history_display: current_history
+            search_btn: gr.Button("Please wait...", interactive=False),
+            chat_history_display: history + [[query, "*Processing...*"]]
         }
 
-        # Generate answer
+        # Generate answer with timeout protection
         prompt = format_prompt(query, web_results)
         answer = generate_answer(prompt)
-        final_answer = answer.split("Answer:")[-1].strip()
 
-        updated_history = history + [[query, final_answer]]
+        # Update with final answer
+        final_history = history + [[query, answer]]
         yield {
-            answer_output: gr.Markdown(final_answer),
+            answer_output: gr.Markdown(answer),
             sources_output: gr.HTML(sources_html),
             search_btn: gr.Button("Search", interactive=True),
-            chat_history_display: updated_history
+            chat_history_display: final_history
         }
-    except Exception as e:
-        error_message = str(e)
-        if "GPU quota" in error_message:
-            error_message = "⚠️ GPU quota exceeded. Please try again later when the daily quota resets."
 
+    except Exception as e:
+        error_msg = f"Error: {str(e)}"
         yield {
-            answer_output: gr.Markdown(f"Error: {error_message}"),
-            sources_output: gr.HTML(sources_html),
+            answer_output: gr.Markdown(error_msg),
+            sources_output: gr.HTML("<div>Error fetching sources</div>"),
             search_btn: gr.Button("Search", interactive=True),
-            chat_history_display: history + [[query, f"*Error: {error_message}*"]]
+            chat_history_display: history + [[query, error_msg]]
        }
 
 # Update the CSS for better contrast and readability
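
Because `model.generate` returns the prompt tokens followed by the completion, the decoded string echoes the whole prompt; the `response.split('Answer:')[-1]` in `generate_answer` strips that echo. A quick smoke test of the pipeline, assuming the definitions above are loaded (the query text is arbitrary, and the search call goes out over the network):

```python
# Smoke test: search -> prompt -> generate.
results = get_web_results("capital of France")   # live DuckDuckGo call
prompt = format_prompt("What is the capital of France?", results)
print(generate_answer(prompt))                   # completion only; prompt echo stripped
```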
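
`process_query` yields dicts keyed by Gradio components (`answer_output`, `sources_output`, `search_btn`, `chat_history_display`), which only works if those components exist in a `gr.Blocks` layout and are listed as outputs of the event handler. That layout sits further down in app.py, outside this diff; a hedged sketch of the shape it has to take:

```python
# Sketch of the Blocks wiring process_query assumes. The four output component
# names come from the yields above; query_input and the labels are guesses,
# since the real layout is outside this diff.
with gr.Blocks() as demo:
    query_input = gr.Textbox(label="Query")  # hypothetical name
    search_btn = gr.Button("Search")
    answer_output = gr.Markdown()
    sources_output = gr.HTML()
    chat_history_display = gr.Chatbot()
    search_btn.click(
        process_query,
        inputs=[query_input, chat_history_display],
        outputs=[answer_output, sources_output, search_btn, chat_history_display],
    )
demo.launch()
```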
 