jeremyarancio committed
Commit 88e1248
Parent(s): 6dec8ee

Update handler

Files changed (2):
  1. README.md (+2 -2)
  2. handler.py (+8 -9)
README.md CHANGED
@@ -26,10 +26,10 @@ from peft import PeftConfig, PeftModel
 
 # Import the model
 config = PeftConfig.from_pretrained("JeremyArancio/llm-tolkien")
-model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, return_dict=True, load_in_8bit=True, device_map='auto')
+model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, load_in_8bit=True, device_map='auto')
 tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
 # Load the Lora model
-model = PeftModel.from_pretrained(model, hf_repo)
+model = PeftModel.from_pretrained(model, "JeremyArancio/llm-tolkien")
 ```
 
 # Run the model
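
For context, a minimal sketch of the "Run the model" step the README continues with. It is not part of this commit: it assumes the `model` and `tokenizer` loaded in the snippet above, and the prompt text and generation parameters are placeholders, not values taken from the repository.

# Illustrative continuation of the README snippet above (not part of this
# commit; prompt and generation parameters are placeholders).
prompt = "The hobbits were so surprised seeing their friend"
inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))
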
handler.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, List, Any
+from typing import Dict, Any
 import logging
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -6,6 +6,7 @@ from peft import PeftConfig, PeftModel
 
 
 LOGGER = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
 
 
 class EndpointHandler():
@@ -16,26 +17,24 @@ class EndpointHandler():
         # Load the Lora model
         self.model = PeftModel.from_pretrained(model, path)
 
-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
         Args:
             data (Dict): The payload with the text prompt and generation parameters.
         """
         LOGGER.info(f"Received data: {data}")
         # Get inputs
-        inputs = data.pop("inputs", data)
+        prompt = data.pop("prompt", data)
         parameters = data.pop("parameters", None)
-        LOGGER.info("Data extracted.")
         # Preprocess
-        LOGGER.info(f"Start tokenizer: {inputs}")
-        inputs_ids = self.tokenizer(inputs, return_tensors="pt").inputs_ids
+        input = self.tokenizer(prompt, return_tensors="pt")
         # Forward
         LOGGER.info(f"Start generation.")
         if parameters is not None:
-            outputs = self.model.generate(inputs_ids, **parameters)
+            output = self.model.generate(**input, **parameters)
         else:
-            outputs = self.model.generate(inputs_ids)
+            output = self.model.generate(**input)
         # Postprocess
-        prediction = self.tokenizer.decode(outputs[0])
+        prediction = self.tokenizer.decode(output[0])
         LOGGER.info(f"Generated text: {prediction}")
         return {"generated_text": prediction}