eagle0504 committed on
Commit
afc3996
1 Parent(s): 4e18d60

Update helper/utils.py

Browse files
Files changed (1) hide show
  1. helper/utils.py +59 -0
helper/utils.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
2
  from datetime import datetime
 
3
  from typing import Any, Dict, List, Tuple, Union
 
4
 
5
  import numpy as np
6
  import pandas as pd
@@ -205,6 +207,63 @@ def call_llama(prompt: str) -> str:
205
  return response.choices[0].message.content
206
 
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  def quantize_to_kbit(arr: Union[np.ndarray, Any], k: int = 16) -> np.ndarray:
209
  """Converts an array to a k-bit representation by normalizing and scaling its values.
210
 
 
1
  import os
2
  from datetime import datetime
3
+ import json
4
  from typing import Any, Dict, List, Tuple, Union
5
+ import requests
6
 
7
  import numpy as np
8
  import pandas as pd
 
207
  return response.choices[0].message.content
208
 
209
 
210
def call_llama2(prompt: str, max_new_tokens: int = 50, temperature: float = 0.9) -> str:
    """
    Calls the Llama API to generate text based on a given prompt, controlling the length and randomness.

    Args:
        prompt (str): The prompt text to send to the Llama model for text generation.
        max_new_tokens (int, optional): The maximum number of tokens that the model should generate. Defaults to 50.
        temperature (float, optional): Controls the randomness of the output. Lower values make the model more
            deterministic; a higher value increases randomness. Defaults to 0.9.

    Returns:
        str: The generated text response from the Llama model (the portion after the "[/INST]" instruction marker).

    Raises:
        Exception: If the API call returns a non-200 status code, with the status code in the message.
    """
    # API endpoint for the Llama model
    api_url = "https://v6rkdcyir7.execute-api.us-east-1.amazonaws.com/beta"

    # Request payload: the prompt is wrapped in the Llama-2 instruction format.
    json_body = {
        "body": {
            "inputs": f"<s>[INST] {prompt} [/INST]",
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "top_p": 0.9,  # Fixed nucleus-sampling cutoff: tokens above this cumulative probability are dropped
                "temperature": temperature,
            },
        }
    }

    # Headers to indicate that the payload is JSON
    headers = {"Content-Type": "application/json"}

    # Perform the POST request to the Llama API
    response = requests.post(api_url, headers=headers, json=json_body)

    # Check the status code BEFORE parsing. Previously the body was parsed first,
    # so a failed call raised an unrelated KeyError/JSONDecodeError instead of the
    # documented Exception below.
    if response.status_code != 200:
        raise Exception(f"Error calling Llama API: {response.status_code}")

    # The API returns JSON whose 'body' field is itself a JSON-encoded string.
    response_body = response.json()['body']

    # Decode that inner JSON string into a list of result objects.
    body_list = json.loads(response_body)

    # Extract the 'generated_text' from the first item in the list
    generated_text = body_list[0]['generated_text']

    # The model echoes the instruction; keep only the answer after "[/INST]".
    return generated_text.split("[/INST]")[-1].strip()
265
+
266
+
267
  def quantize_to_kbit(arr: Union[np.ndarray, Any], k: int = 16) -> np.ndarray:
268
  """Converts an array to a k-bit representation by normalizing and scaling its values.
269