Vladislav Sokolovskii commited on
Commit
18e819d
1 Parent(s): 9b71b69

Add handler and reqs

Browse files
Files changed (2) hide show
  1. handler.py +71 -0
  2. requirements.txt +8 -0
handler.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List, Any
3
+ from unsloth import FastLanguageModel
4
+ from unsloth.chat_templates import get_chat_template
5
+ import torch
6
+ from huggingface_hub import login
7
+ import os
8
+
9
+ class EndpointHandler:
10
+ def __init__(self, path=""):
11
+ # access_token = os.environ["HUGGINGFACE_TOKEN"]
12
+ # login(token=access_token)
13
+ # Load the model and tokenizer
14
+ self.model, self.tokenizer = FastLanguageModel.from_pretrained(
15
+ model_name = path, # Use the current directory path
16
+ max_seq_length = 2048,
17
+ dtype = None,
18
+ load_in_4bit = True,
19
+ )
20
+ FastLanguageModel.for_inference(self.model)
21
+
22
+ # Set up the chat template
23
+ self.tokenizer = get_chat_template(
24
+ self.tokenizer,
25
+ chat_template="llama-3",
26
+ mapping={"role": "from", "content": "value", "user": "human", "assistant": "gpt"}
27
+ )
28
+
29
+ def __call__(self, data: Dict[str, Any]) -> List[str]:
30
+ inputs = data.pop("inputs", data)
31
+ parameters = data.pop("parameters", {})
32
+
33
+ # Extract parameters or use defaults
34
+ max_tokens = parameters.get("max_new_tokens", 512)
35
+ temperature = parameters.get("temperature", 0.2)
36
+ top_p = parameters.get("top_p", 0.5)
37
+ system_message = parameters.get("system_message", "")
38
+
39
+ # Prepare messages
40
+ messages = [{"from": "human", "value": system_message}]
41
+ if isinstance(inputs, str):
42
+ messages.append({"from": "human", "value": inputs})
43
+ elif isinstance(inputs, list):
44
+ for msg in inputs:
45
+ role = "human" if msg["role"] == "user" else "gpt"
46
+ messages.append({"from": role, "value": msg["content"]})
47
+
48
+ # Tokenize input
49
+ tokenized_input = self.tokenizer.apply_chat_template(
50
+ messages,
51
+ tokenize=True,
52
+ add_generation_prompt=True,
53
+ return_tensors="pt"
54
+ ).to("cuda")
55
+
56
+ # Generate output
57
+ with torch.no_grad():
58
+ output = self.model.generate(
59
+ input_ids=tokenized_input,
60
+ max_new_tokens=max_tokens,
61
+ temperature=temperature,
62
+ top_p=top_p,
63
+ use_cache=True
64
+ )
65
+
66
+ # Decode and process the output
67
+ full_response = self.tokenizer.decode(output[0], skip_special_tokens=True)
68
+ response_lines = [line.strip() for line in full_response.split('\n') if line.strip()]
69
+ last_response = response_lines[-1] if response_lines else ""
70
+
71
+ return [last_response]
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torchvision
2
+ xformers<0.0.27
3
+ trl==0.8.6
4
+ transformers==4.44.2
5
+ bitsandbytes==0.43.3
6
+ peft==0.12.0
7
+ accelerate>=0.34.2
8
+ git+https://github.com/unslothai/unsloth.git