niruemon commited on
Commit
2e90b71
1 Parent(s): 67b1056

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +19 -14
handler.py CHANGED
@@ -1,28 +1,32 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
2
  import torch
3
  import os
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
- # ระบุชื่อโมเดลใน Hugging Face Hub
8
- model_name = "niruemon/llm-swp"
 
 
 
 
 
9
 
10
  # กำหนดไดเรกทอรีสำหรับการ offload โมเดล (สร้างขึ้นถ้ายังไม่มี)
11
  offload_dir = "./offload"
12
  os.makedirs(offload_dir, exist_ok=True)
13
 
14
  # โหลดโมเดลและ tokenizer
15
- self.model = AutoModelForCausalLM.from_pretrained(
16
- model_name,
17
- device_map="auto",
18
- torch_dtype=torch.float16,
19
- offload_state_dict=True, # เปิดใช้งานการ offload state dict
20
- offload_dir=offload_dir # ระบุ `offload_dir` โดยตรงเพื่อจัดการการ offload ให้ถูกต้อง
21
  )
22
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
23
 
24
- # สร้าง pipeline สำหรับการสร้างข้อความ
25
- self.generator = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer, device_map="auto", offload_folder=offload_dir)
26
 
27
  def __call__(self, data):
28
  # รับข้อความ input จากผู้ใช้
@@ -32,8 +36,9 @@ class EndpointHandler:
32
 
33
  # สร้างข้อความโดยใช้โมเดล
34
  try:
35
- result = self.generator(input_text, max_length=150, num_return_sequences=1)
36
- generated_text = result[0]["generated_text"]
 
37
  return {"generated_text": generated_text}
38
  except Exception as e:
39
  return {"error": str(e)}
 
1
+ from unsloth import FastLanguageModel
2
  import torch
3
  import os
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
+ # ระบุชื่อโมเดลที่คุณต้องการใช้งาน
8
+ model_name = "defog/llama-3-sqlcoder-8b"
9
+
10
+ # Configuration settings
11
+ max_seq_length = 2048
12
+ dtype = None # Keep as None, you can change later if needed
13
+ load_in_4bit = True
14
 
15
  # กำหนดไดเรกทอรีสำหรับการ offload โมเดล (สร้างขึ้นถ้ายังไม่มี)
16
  offload_dir = "./offload"
17
  os.makedirs(offload_dir, exist_ok=True)
18
 
19
  # โหลดโมเดลและ tokenizer
20
+ self.model, self.tokenizer = FastLanguageModel.from_pretrained(
21
+ model_name=model_name,
22
+ max_seq_length=max_seq_length,
23
+ dtype=dtype,
24
+ load_in_4bit=load_in_4bit,
25
+ offload_folder=offload_dir # ระบุโฟลเดอร์สำหรับการ offload
26
  )
 
27
 
28
+ # เตรียมโมเดลสำหรับการประมวลผลข้อความ
29
+ FastLanguageModel.for_inference(self.model)
30
 
31
  def __call__(self, data):
32
  # รับข้อความ input จากผู้ใช้
 
36
 
37
  # สร้างข้อความโดยใช้โมเดล
38
  try:
39
+ inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
40
+ outputs = self.model.generate(**inputs, max_new_tokens=150)
41
+ generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
42
  return {"generated_text": generated_text}
43
  except Exception as e:
44
  return {"error": str(e)}