njwright92 commited on
Commit
0b76d3c
·
verified ·
1 Parent(s): fa89c1e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +59 -27
handler.py CHANGED
@@ -1,25 +1,42 @@
1
  from typing import Dict, List, Any
2
  from llama_cpp import Llama
3
  import gemma_tools
 
4
 
5
  MAX_TOKENS = 1000
6
 
7
 
8
- class EndpointHandler():
9
- def __init__(self, model_dir=None):
10
- if model_dir:
11
-
12
- # Initialize the Llama model directly
13
 
14
- self.model = Llama(
15
- # Adjust the path if necessary
16
- model_path=f"{model_dir}/ComicBot_v.2-gguf",
17
- n_ctx=MAX_TOKENS,
18
- )
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
21
- # Extract and validate arguments from the data
 
22
 
 
 
 
 
23
  args_check = gemma_tools.get_args_or_none(data)
24
 
25
  if not args_check[0]: # If validation failed
@@ -29,26 +46,25 @@ class EndpointHandler():
29
  "description": args_check.get("description", "Validation error in arguments")
30
  }]
31
 
32
- args = args_check # If validation passed, args are in args_check
 
33
 
34
- # Define the formatting template
35
- fmat = "<startofturn>system\n{system_prompt} <endofturn>\n<startofturn>user\n{inputs} <endofturn>\n<startofturn>model"
36
 
37
  try:
38
- formatted_prompt = fmat.format(**args)
39
-
40
  except Exception as e:
41
-
42
  return [{
43
  "status": "error",
44
  "reason": "Invalid format",
45
  "detail": str(e)
46
  }]
47
 
 
48
  max_length = data.get("max_length", 212)
49
  try:
50
  max_length = int(max_length)
51
-
52
  except ValueError:
53
  return [{
54
  "status": "error",
@@ -56,16 +72,32 @@ class EndpointHandler():
56
  "detail": "max_length was not a valid integer"
57
  }]
58
 
59
- res = self.model(
60
- formatted_prompt,
61
- temperature=args["temperature"],
62
- top_p=args["top_p"],
63
- top_k=args["top_k"],
64
- max_tokens=max_length
65
- )
 
 
 
 
 
 
 
 
66
 
67
  return [{
68
  "status": "success",
69
- # Assuming Llama's response format
70
- "response": res['choices'][0]['text']
71
  }]
 
 
 
 
 
 
 
 
 
1
  from typing import Dict, List, Any
2
  from llama_cpp import Llama
3
  import gemma_tools
4
+ import os
5
 
6
  MAX_TOKENS = 1000
7
 
8
 
9
+ class EndpointHandler:
10
+ def __init__(self, model_dir: str = None):
11
+ """
12
+ Initialize the EndpointHandler with the path to the model directory.
 
13
 
14
+ :param model_dir: Path to the directory containing the model file.
15
+ """
16
+ if model_dir:
17
+ # Update the model filename to match the one in your repository
18
+ model_path = os.path.join(
19
+ model_dir, "comic_mistral-v5.2.q5_0.gguf")
20
+ if not os.path.exists(model_path):
21
+ raise FileNotFoundError(
22
+ f"The model file was not found at {model_path}")
23
+
24
+ try:
25
+ self.model = Llama(
26
+ model_path=model_path,
27
+ n_ctx=MAX_TOKENS, # Use n_ctx for context size in llama_cpp
28
+ )
29
+ except Exception as e:
30
+ raise RuntimeError(f"Failed to load the model: {e}")
31
 
32
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
33
+ """
34
+ Handle incoming requests for model inference.
35
 
36
+ :param data: Dictionary containing input data and parameters for the model.
37
+ :return: A list with a dictionary containing the status and response or error details.
38
+ """
39
+ # Extract and validate arguments from the data
40
  args_check = gemma_tools.get_args_or_none(data)
41
 
42
  if not args_check[0]: # If validation failed
 
46
  "description": args_check.get("description", "Validation error in arguments")
47
  }]
48
 
49
+ # If validation passed, args are in the second element of the tuple
50
+ args = args_check[1]
51
 
52
+ # Define the formatting template for the prompt
53
+ prompt_format = "<startofturn>system\n{system_prompt} <endofturn>\n<startofturn>user\n{inputs} <endofturn>\n<startofturn>model"
54
 
55
  try:
56
+ formatted_prompt = prompt_format.format(**args)
 
57
  except Exception as e:
 
58
  return [{
59
  "status": "error",
60
  "reason": "Invalid format",
61
  "detail": str(e)
62
  }]
63
 
64
+ # Parse max_length, default to 212 if not provided or invalid
65
  max_length = data.get("max_length", 212)
66
  try:
67
  max_length = int(max_length)
 
68
  except ValueError:
69
  return [{
70
  "status": "error",
 
72
  "detail": "max_length was not a valid integer"
73
  }]
74
 
75
+ # Perform inference
76
+ try:
77
+ res = self.model(
78
+ formatted_prompt,
79
+ temperature=args["temperature"],
80
+ top_p=args["top_p"],
81
+ top_k=args["top_k"],
82
+ max_tokens=max_length
83
+ )
84
+ except Exception as e:
85
+ return [{
86
+ "status": "error",
87
+ "reason": "Inference failed",
88
+ "detail": str(e)
89
+ }]
90
 
91
  return [{
92
  "status": "success",
93
+ # Extract the text from the response
94
+ "response": res['choices'][0]['text'].strip()
95
  }]
96
+
97
+
98
+ # Usage in your script or where the handler is instantiated:
99
+ try:
100
+ handler = EndpointHandler("/repository")
101
+ except (FileNotFoundError, RuntimeError) as e:
102
+ print(f"Initialization error: {e}")
103
+ exit(1) # Exit with an error code if the handler cannot be initialized