Adapters
Inference Endpoints
JeremyArancio commited on
Commit
c8b5fa1
1 Parent(s): 571225f

Add handler and requirements

Browse files
Files changed (2) hide show
  1. handler.py +48 -0
  2. requirements.txt +1 -0
handler.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from peft import PeftConfig, PeftModel
5
+
6
+
7
+ class EndpointHandler():
8
+ def __init__(self, path=""):
9
+ config = PeftConfig.from_pretrained(path)
10
+ model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, return_dict=True, load_in_8bit=True, device_map='auto')
11
+ self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
12
+ # Load the Lora model
13
+ self.model = PeftModel.from_pretrained(model, path)
14
+
15
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
+ """
17
+ data args:
18
+ prompt (:obj:`str`):
19
+ temperature (:obj:`float`, `optional`, defaults to 0.5):
20
+ eos_token_id (:obj:`int`, `optional`, defaults to tokenizer.eos_token_id):
21
+ early_stopping (:obj:`bool`, `optional`, defaults to `True`):
22
+ repetition_penalty (:obj:`float`, `optional`, defaults to 0.3):
23
+ Return:
24
+ A :obj:`str` : generated sequences
25
+ """
26
+ # Get inputs
27
+ prompt = data.pop("prompt", None)
28
+ temperature = data.pop("temperature", 0.5)
29
+ eos_token_id = data.pop("eos_token_id", self.tokenizer.eos_token_id)
30
+ early_stopping = data.pop('early_stopping', True)
31
+ repetition_penalty = data.pop('repetition_penalty', 0.3)
32
+ max_new_tokens = data.pop('max_new_tokens', 100)
33
+
34
+ if prompt is None:
35
+ raise ValueError("No prompt provided.")
36
+
37
+ # Run prediction
38
+ inputs = self.tokenizer(prompt, return_tensors="pt")
39
+ prediction = self.model.generate(
40
+ **inputs,
41
+ temperature=temperature,
42
+ eos_token_id=eos_token_id,
43
+ early_stopping=early_stopping,
44
+ repetition_penalty=repetition_penalty,
45
+ max_new_tokens=max_new_tokens
46
+ )
47
+
48
+ return prediction
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ peft==0.3.0