TangrisJones committed
Commit 5f6dd19
1 Parent(s): 05f197e

Upload 2 files

Files changed (2):
  1. handler.py +34 -0
  2. requirements.txt +2 -0
handler.py ADDED
@@ -0,0 +1,34 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from typing import Dict, List, Any
+
+
+ class EndpointHandler:
+     # def __init__(self, path="decapoda-research/llama-65b-hf"):
+     def __init__(self, path="anon8231489123/vicuna-13b-GPTQ-4bit-128g"):
+         self.tokenizer = AutoTokenizer.from_pretrained(path)
+         self.model = AutoModelForCausalLM.from_pretrained(path)
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         input_text = data["inputs"]
+         kwargs = data.get("kwargs", {})
+
+         # Tokenize input text
+         input_tokens = self.tokenizer.encode(input_text, return_tensors="pt")
+
+         # Generate output tokens
+         with torch.no_grad():
+             output_tokens = self.model.generate(input_tokens, **kwargs)
+
+         # Decode output tokens
+         output_text = self.tokenizer.decode(output_tokens[0])
+
+         return [{"output": output_text}]
+
+
+ # Example usage
+ if __name__ == "__main__":
+     handler = EndpointHandler()
+     input_data = {"inputs": "Once upon a time in a small village, "}
+     output_data = handler(input_data)
+     print(output_data)
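
For reference, a minimal sketch of how this handler could be exercised locally, assuming handler.py is importable from the working directory and that the checkpoint loads with plain AutoModelForCausalLM. The generation arguments below (max_new_tokens, do_sample, temperature) are standard transformers generate() options chosen for illustration; they are not part of the commit. The handler forwards everything under the "kwargs" key straight to model.generate().

# Illustrative local call of the committed handler; parameter values are assumptions.
from handler import EndpointHandler

handler = EndpointHandler()  # downloads and loads the 13B checkpoint; needs substantial RAM/VRAM

payload = {
    "inputs": "Explain what a tokenizer does in one sentence.",
    "kwargs": {                  # forwarded verbatim to model.generate(...)
        "max_new_tokens": 64,
        "do_sample": True,
        "temperature": 0.7,
    },
}

print(handler(payload))  # -> [{"output": "..."}]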
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ transformers==4.29.1
+ tokenizers==0.13.3
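
The two pins cover the Hugging Face libraries that handler.py relies on; torch itself is imported by the handler but not listed here, presumably because the endpoint's base image is expected to provide it. A quick, purely illustrative environment check (not part of the committed files):

# Illustrative sanity check of the runtime environment; not part of the commit.
import torch
import transformers
import tokenizers

print("torch:", torch.__version__)                # not pinned in requirements.txt; assumed preinstalled
print("transformers:", transformers.__version__)  # expected: 4.29.1
print("tokenizers:", tokenizers.__version__)      # expected: 0.13.3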