Chris4K commited on
Commit
af9f214
1 Parent(s): ef1764f

Create custom_agent.py

Browse files
Files changed (1) hide show
  1. custom_agent.py +38 -0
custom_agent.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # custom_agent.py
2
+ import requests
3
+ import time
4
+ from transformers import Agent
5
+
6
+ class CustomHfAgent(Agent):
7
+ def __init__(self, url_endpoint, token, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, input_params=None):
8
+ super().__init__(
9
+ chat_prompt_template=chat_prompt_template,
10
+ run_prompt_template=run_prompt_template,
11
+ additional_tools=additional_tools,
12
+ )
13
+ self.url_endpoint = url_endpoint
14
+ self.token = token
15
+ self.input_params = input_params
16
+
17
+ def generate_one(self, prompt, stop):
18
+ headers = {"Authorization": self.token}
19
+ max_new_tokens = self.input_params.get("max_new_tokens", 192)
20
+ parameters = {"max_new_tokens": max_new_tokens, "return_full_text": False, "stop": stop, "padding": True, "truncation": True}
21
+ inputs = {
22
+ "inputs": prompt,
23
+ "parameters": parameters,
24
+ }
25
+ response = requests.post(self.url_endpoint, json=inputs, headers=headers)
26
+
27
+ if response.status_code == 429:
28
+ log_response("Getting rate-limited, waiting a tiny bit before trying again.")
29
+ time.sleep(1)
30
+ return self._generate_one(prompt)
31
+ elif response.status_code != 200:
32
+ raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
33
+ log_response(response)
34
+ result = response.json()[0]["generated_text"]
35
+ for stop_seq in stop:
36
+ if result.endswith(stop_seq):
37
+ return result[: -len(stop_seq)]
38
+ return result