alx-d committed on
Commit
9236d0a
·
verified ·
1 Parent(s): 01f968f

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. advanced_rag.py +22 -8
advanced_rag.py CHANGED
@@ -150,15 +150,29 @@ class ElevatedRagChain:
150
  if not hf_api_token:
151
  raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
152
  client = InferenceClient(token=hf_api_token, timeout=240)
 
 
153
  def remote_generate(prompt: str) -> str:
154
- response = client.text_generation(
155
- prompt,
156
- model=repo_id,
157
- temperature=self.temperature,
158
- top_p=self.top_p,
159
- repetition_penalty=1.1,
160
- wait_for_model=True,
161
- )
 
 
 
 
 
 
 
 
 
 
 
 
162
  return response
163
  from langchain.llms.base import LLM
164
  class RemoteLLM(LLM):
 
150
  if not hf_api_token:
151
  raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
152
  client = InferenceClient(token=hf_api_token, timeout=240)
153
+
154
+ from huggingface_hub.utils._errors import HfHubHTTPError
155
  def remote_generate(prompt: str) -> str:
156
+ max_retries = 5
157
+ backoff = 2 # start with 2 seconds
158
+ response = None
159
+ for attempt in range(max_retries):
160
+ try:
161
+ response = client.text_generation(
162
+ prompt,
163
+ model=repo_id,
164
+ temperature=self.temperature,
165
+ top_p=self.top_p,
166
+ repetition_penalty=1.1
167
+ )
168
+ return response
169
+ except HfHubHTTPError as e:
170
+ debug_print(f"Attempt {attempt+1} failed with error: {e}")
171
+ # if this is the last attempt, re-raise the error
172
+ if attempt == max_retries - 1:
173
+ raise
174
+ time.sleep(backoff)
175
+ backoff *= 2 # exponential backoff
176
  return response
177
  from langchain.llms.base import LLM
178
  class RemoteLLM(LLM):