Shreyas094 committed
Commit 8b5e7fa · verified · 1 Parent(s): f3cc462

Update app.py

Files changed (1): app.py (+112 -107)
app.py CHANGED
@@ -1,27 +1,18 @@
 import os
 import logging
-import json
-import time
+import asyncio
 import gradio as gr
 from huggingface_hub import InferenceClient
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.schema import Document
 from duckduckgo_search import DDGS
-from dotenv import load_dotenv
-from functools import lru_cache
-from tenacity import retry, stop_after_attempt, wait_fixed
-
-# Load environment variables
-load_dotenv()
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
 
 # Environment variables and configurations
-HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
-logger.info(f"Using Hugging Face token: {HUGGINGFACE_TOKEN[:4]}...{HUGGINGFACE_TOKEN[-4:] if HUGGINGFACE_TOKEN else 'Not Set'}")
+huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
 
 MODELS = [
     "mistralai/Mistral-7B-Instruct-v0.3",
@@ -33,8 +24,7 @@ MODELS = [
     "google/gemma-2-27b-it"
 ]
 
-FALLBACK_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
-
+# Default system message template
 DEFAULT_SYSTEM_PROMPT = """You are a world-class financial AI assistant, capable of complex reasoning and reflection.
 Reason through the query inside <thinking> tags, and then provide your final response inside <output> tags.
 Providing comprehensive and accurate information based on web search results is essential.
@@ -42,47 +32,36 @@ Your goal is to synthesize the given context into a coherent and detailed respon
 Please ensure that your response is well-structured and factual.
 If you detect that you made a mistake in your reasoning at any point, correct yourself inside <reflection> tags."""
 
-class WebSearcher:
-    def __init__(self):
-        self.ddgs = DDGS()
-
-    @lru_cache(maxsize=100)
-    def search(self, query, max_results=5):
-        try:
-            results = list(self.ddgs.text(query, max_results=max_results))
-            logger.info(f"Search completed for query: {query}")
-            return results
-        except Exception as e:
-            logger.error(f"Error during DuckDuckGo search: {str(e)}")
-            return []
-
-@lru_cache(maxsize=1)
 def get_embeddings():
     return HuggingFaceEmbeddings(model_name="sentence-transformers/stsb-roberta-large")
 
+def duckduckgo_search(query):
+    try:
+        with DDGS() as ddgs:
+            results = ddgs.text(query, max_results=5)
+        logging.info(f"Search completed for query: {query}")
+        return results
+    except Exception as e:
+        logging.error(f"Error during DuckDuckGo search: {str(e)}")
+        return []
+
 def create_web_search_vectors(search_results):
     embed = get_embeddings()
-    documents = [
-        Document(
-            page_content=f"{result['title']}\n{result['body']}\nSource: {result['href']}",
-            metadata={"source": result['href']}
-        )
-        for result in search_results if 'body' in result
-    ]
-    logger.info(f"Created vectors for {len(documents)} search results.")
+    documents = []
+    for result in search_results:
+        if 'body' in result:
+            content = f"{result['title']}\n{result['body']}\nSource: {result['href']}"
+            documents.append(Document(page_content=content, metadata={"source": result['href']}))
+    logging.info(f"Created vectors for {len(documents)} search results.")
     return FAISS.from_documents(documents, embed)
 
-@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
-def make_api_call(client, api_params):
-    return client.chat_completion(**api_params)
-
-def get_response_with_search(query, system_prompt, model, use_embeddings, history, num_calls=3, temperature=0.2):
-    searcher = WebSearcher()
-    search_results = searcher.search(query)
+async def get_response_with_search(query, system_prompt, model, use_embeddings, history=None, num_calls=3, temperature=0.2):
+    search_results = duckduckgo_search(query)
 
     if not search_results:
-        logger.warning(f"No web search results found for query: {query}")
-        return "No web search results available. Please try again.", ""
+        logging.warning(f"No web search results found for query: {query}")
+        yield "No web search results available. Please try again.", ""
+        return
 
     sources = [result['href'] for result in search_results if 'href' in result]
     source_list_str = "\n".join(sources)
@@ -95,80 +74,105 @@ def get_response_with_search(query, system_prompt, model, use_embeddings, histor
     else:
         context = "\n".join([f"{result['title']}\n{result['body']}" for result in search_results])
 
-    logger.info(f"Context created for query: {query}")
+    logging.info(f"Context created for query: {query}")
 
-    chat_history = "\n".join([f"Human: {h[0]}\nAI: {h[1]}" for h in history])
-    user_message = f"""Chat history:
-    {chat_history}
-
-    Using the following context from web search results:
+    user_message = f"""Using the following context from web search results:
    {context}
 
    Write a detailed and complete research document that fulfills the following user request: '{query}'."""
 
-    client = InferenceClient(model, token=HUGGINGFACE_TOKEN)
+    client = InferenceClient(model, token=huggingface_token)
     full_response = ""
+
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": user_message}
+    ]
+
+    # Include chat history if provided
+    if history:
+        messages = history + messages
+
     try:
-        for _ in range(num_calls):
-            api_params = {
-                "messages": [
-                    {"role": "system", "content": system_prompt},
-                    {"role": "user", "content": user_message}
-                ],
-                "max_tokens": 3000,
-                "temperature": temperature,
-                "top_p": 0.8,
-            }
-            logger.info(f"Sending request to API with params: {json.dumps(api_params, indent=2, default=str)}")
-            response = make_api_call(client, api_params)
-            logger.info(f"Raw response from model: {response}")
-
-            if isinstance(response, dict):
-                if 'generated_text' in response:
-                    full_response += response['generated_text']
-                elif 'choices' in response and len(response['choices']) > 0:
-                    if isinstance(response['choices'][0], dict) and 'message' in response['choices'][0]:
-                        full_response += response['choices'][0]['message'].get('content', '')
-                    elif isinstance(response['choices'][0], str):
-                        full_response += response['choices'][0]
-            elif hasattr(response, 'generated_text'):
-                full_response += response.generated_text
-            elif hasattr(response, 'content'):
-                full_response += response.content
-            else:
-                logger.error(f"Unexpected response format from the model: {type(response)}")
-                return "Unexpected response format from the model. Please try again.", ""
+        for call in range(num_calls):
+            try:
+                for response in client.chat_completion(
+                    messages=messages,
+                    max_tokens=6000,
+                    temperature=temperature,
+                    stream=True,
+                    top_p=0.8,
+                ):
+                    if isinstance(response, dict) and "choices" in response:
+                        for choice in response["choices"]:
+                            if "delta" in choice and "content" in choice["delta"]:
+                                chunk = choice["delta"]["content"]
+                                full_response += chunk
+                                yield full_response, ""
+                    else:
+                        logging.error("Unexpected response format or missing attributes in the response object.")
+                        break
+            except Exception as e:
+                logging.error(f"Error in API call {call + 1}: {str(e)}")
+                if "422 Client Error" in str(e):
+                    logging.warning("Received 422 Client Error. Adjusting request parameters.")
+                    # You might want to adjust parameters here, e.g., reduce max_tokens
+                yield f"An error occurred during API call {call + 1}. Retrying...", ""
 
-            time.sleep(1)  # Add a 1-second delay between API calls
-    except Exception as e:
-        logger.error(f"Error in get_response_with_search: {str(e)}")
-        logger.info(f"Attempting fallback to {FALLBACK_MODEL}")
-        client = InferenceClient(FALLBACK_MODEL, token=HUGGINGFACE_TOKEN)
-        # Retry with fallback model (you can implement retry logic here)
-        return f"An error occurred while processing your request: {str(e)}", ""
+            # Add a small delay between API calls
+            await asyncio.sleep(1)  # 1 second delay
+
+    except asyncio.CancelledError:
+        logging.warning("The operation was cancelled.")
+        yield "The operation was cancelled. Please try again.", ""
 
     if not full_response:
-        logger.warning("No response generated from the model")
-        return "No response generated from the model.", ""
-    else:
-        return f"{full_response}\n\nSources:\n{source_list_str}", ""
-
-def respond(message, system_prompt, history, model, temperature, num_calls, use_embeddings):
-    logger.info(f"Respond function called with message: {message}")
-    logger.info(f"User Query: {message}")
-    logger.info(f"Model Used: {model}")
-    logger.info(f"Temperature: {temperature}")
-    logger.info(f"Number of API Calls: {num_calls}")
-    logger.info(f"Use Embeddings: {use_embeddings}")
-    logger.info(f"System Prompt: {system_prompt}")
-    logger.info(f"History: {history}")
+        logging.warning("No response generated from the model")
+        yield "No response generated from the model.", ""
+
+    yield f"{full_response}\n\nSources:\n{source_list_str}", ""
+
+async def respond(message, system_prompt, history, model, temperature, num_calls, use_embeddings):
+    logging.info(f"User Query: {message}")
+    logging.info(f"Model Used: {model}")
+    logging.info(f"Temperature: {temperature}")
+    logging.info(f"Number of API Calls: {num_calls}")
+    logging.info(f"Use Embeddings: {use_embeddings}")
+    logging.info(f"System Prompt: {system_prompt}")
+
+    # Convert gradio history to the format expected by get_response_with_search
+    chat_history = []
+    for human, assistant in history:
+        chat_history.append({"role": "user", "content": human})
+        if assistant:
+            chat_history.append({"role": "assistant", "content": assistant})
 
     try:
-        main_content, sources = get_response_with_search(message, system_prompt, model, use_embeddings, history, num_calls=num_calls, temperature=temperature)
-        return main_content
+        full_response = ""
+        async for main_content, sources in get_response_with_search(
+            message,
+            system_prompt,
+            model,
+            use_embeddings,
+            history=chat_history,
+            num_calls=num_calls,
+            temperature=temperature
+        ):
+            # Yield only the new content
+            new_content = main_content[len(full_response):]
+            full_response = main_content
+            yield new_content
+
+        # Yield the sources as a separate message
+        if sources:
+            yield f"\n\nSources:\n{sources}"
+
+    except asyncio.CancelledError:
+        logging.warning("The operation was cancelled.")
+        yield "The operation was cancelled. Please try again."
     except Exception as e:
-        logger.error(f"Error in respond function: {str(e)}")
-        return f"An error occurred: {str(e)}"
+        logging.error(f"Error in respond function: {str(e)}")
+        yield f"An error occurred: {str(e)}"
 
 css = """
 /* Fine-tune chatbox size */
@@ -182,6 +186,7 @@ css = """
 }
 """
 
+# Gradio interface setup
 def create_gradio_interface():
     custom_placeholder = "Enter your question here for web search."
 
@@ -232,4 +237,4 @@ def create_gradio_interface():
 
 if __name__ == "__main__":
     demo = create_gradio_interface()
-    demo.launch(share=True)
+    demo.launch(share=True)
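For reference, the new duckduckgo_search() helper replaces the cached WebSearcher class removed above and can be exercised on its own. A minimal sketch, assuming network access and the 'title'/'href'/'body' keys that DDGS.text results carry; the query string is illustrative:

    # Hypothetical smoke test for the helper introduced in this commit.
    results = duckduckgo_search("US CPI release")
    for result in results:
        print(result['title'], '-', result['href'])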
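The use_embeddings branch is elided from the hunks above, but the FAISS store built by create_web_search_vectors is presumably queried with LangChain's standard similarity_search before the context string is assembled. A sketch under that assumption; the variable names, query, and k=5 are illustrative and not taken from the elided lines:

    # Hypothetical reconstruction of the embeddings path (not part of the commit).
    search_results = duckduckgo_search("ECB rate decision")
    store = create_web_search_vectors(search_results)
    relevant_docs = store.similarity_search("ECB rate decision", k=5)
    context = "\n".join(doc.page_content for doc in relevant_docs)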
 
+ yield f"An error occurred: {str(e)}"
176
 
177
  css = """
178
  /* Fine-tune chatbox size */
 
186
  }
187
  """
188
 
189
+ # Gradio interface setup
190
  def create_gradio_interface():
191
  custom_placeholder = "Enter your question here for web search."
192
 
 
237
 
238
  if __name__ == "__main__":
239
  demo = create_gradio_interface()
240
+ demo.launch(share=True)