subhankarfynd commited on
Commit
a9e85ee
·
1 Parent(s): ba48f5d

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +71 -0
handler.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import bitsandbytes as bnb
3
+ import torch
4
+ import transformers
5
+ from datasets import load_dataset
6
+ from typing import Dict, List, Any
7
+ from peft import (
8
+ LoraConfig,
9
+ PeftConfig,
10
+ PeftModel,
11
+ get_peft_model,
12
+ prepare_model_for_kbit_training,
13
+ )
14
+ from transformers import (
15
+ AutoConfig,
16
+ LlamaTokenizer,
17
+ LlamaForCausalLM,
18
+ #AutoModelForCausalLM,
19
+ #AutoTokenizer,
20
+ BitsAndBytesConfig,
21
+ )
22
+ import json
23
+
24
+ bnb_config = BitsAndBytesConfig(
25
+ load_in_4bit=True,
26
+ bnb_4bit_use_double_quant=True,
27
+ bnb_4bit_quant_type="nf4",
28
+ bnb_4bit_compute_dtype=torch.bfloat16,
29
+ )
30
+
31
+
32
+ from huggingface_hub import login
33
+ access_token_read = "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"
34
+ login(token = access_token_read)
35
+
36
+
37
+ class EndpointHandler:
38
+ def __init__(self, path=''):
39
+ PEFT_MODEL = path
40
+ config = PeftConfig.from_pretrained(PEFT_MODEL)
41
+ self.model = LlamaForCausalLM.from_pretrained(
42
+ config.base_model_name_or_path,
43
+ return_dict=True,
44
+ quantization_config=bnb_config,
45
+ device_map="auto",
46
+ trust_remote_code=True,
47
+ )
48
+ self.tokenizer = LlamaTokenizer.from_pretrained(config.base_model_name_or_path)
49
+ self.tokenizer.pad_token_id = (0)
50
+ self.tokenizer.padding_side = "left"
51
+ self.model = PeftModel.from_pretrained(self.model, PEFT_MODEL)
52
+ self.generation_config = self.model.generation_config
53
+ self.generation_config.max_new_tokens = 500
54
+ self.generation_config.pad_token_id = self.tokenizer.eos_token_id
55
+ self.generation_config.eos_token_id = self.tokenizer.eos_token_id
56
+
57
+
58
+
59
+
60
+ def __call__(self, data: Dict[str, Any]):
61
+ prompt = data.pop("inputs", data)
62
+ DEVICE = "cuda:0"
63
+ input_message = f"""[INST]You are Copilot, a chat assistant that helps users choose products from JioMart, JioFiber, JioCinema, Tira Beauty, netmeds and milkbasket[/INST]\nUser: {prompt}\nAssistant: """.strip()
64
+ encoding = self.tokenizer(input_message, return_tensors="pt").to(DEVICE)
65
+ with torch.inference_mode():
66
+ outputs = self.model.generate(
67
+ input_ids=encoding.input_ids,
68
+ attention_mask=encoding.attention_mask,
69
+ generation_config=self.generation_config
70
+ )
71
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)[len(input_message):]