Sirawitch commited on
Commit
a90d622
1 Parent(s): fe70976

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -24
app.py CHANGED
@@ -2,11 +2,9 @@ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from typing import Optional
4
  import torch
5
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
6
  import logging
7
- import os
8
 
9
- # ตั้งค่า logging
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
@@ -16,32 +14,21 @@ try:
16
  model_name = "scb10x/llama-3-typhoon-v1.5-8b-instruct"
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
 
19
- # ตรวจสอบว่ามี GPU หรือไม่
20
- if torch.cuda.is_available():
21
- logger.info("GPU is available. Using CUDA.")
22
- device = "cuda"
23
- else:
24
- logger.info("No GPU found. Using CPU.")
25
- device = "cpu"
26
 
27
- # กำหนดการตั้งค่าสำหรับการโหลดโมเดล
28
- model_kwargs = {
29
- "torch_dtype": torch.float32 if device == "cpu" else torch.float16,
30
- "low_cpu_mem_usage": True,
31
- }
32
-
33
- if device == "cuda":
34
- from transformers import BitsAndBytesConfig
35
- model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
36
 
37
- # โหลดโมเดล
38
  model = AutoModelForCausalLM.from_pretrained(
39
  model_name,
40
- device_map="auto" if device == "cuda" else None,
41
- **model_kwargs
 
42
  )
43
-
44
- model.to(device)
45
  logger.info(f"Model loaded successfully on {device}")
46
  except Exception as e:
47
  logger.error(f"Error loading model: {str(e)}")
 
2
  from pydantic import BaseModel
3
  from typing import Optional
4
  import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
6
  import logging
 
7
 
 
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
 
14
  model_name = "scb10x/llama-3-typhoon-v1.5-8b-instruct"
15
  tokenizer = AutoTokenizer.from_pretrained(model_name)
16
 
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ logger.info(f"Using device: {device}")
 
 
 
 
 
19
 
20
+ # 4-bit quantization configuration
21
+ quantization_config = BitsAndBytesConfig(
22
+ load_in_4bit=True,
23
+ bnb_4bit_compute_dtype=torch.float16
24
+ )
 
 
 
 
25
 
 
26
  model = AutoModelForCausalLM.from_pretrained(
27
  model_name,
28
+ quantization_config=quantization_config,
29
+ device_map="auto",
30
+ low_cpu_mem_usage=True,
31
  )
 
 
32
  logger.info(f"Model loaded successfully on {device}")
33
  except Exception as e:
34
  logger.error(f"Error loading model: {str(e)}")