Instructions to use SPAISS6F1/gemma-1b-pruned-th with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use SPAISS6F1/gemma-1b-pruned-th with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="SPAISS6F1/gemma-1b-pruned-th")# Load model directly from transformers import AutoProcessor, AutoModelForMultimodalLM processor = AutoProcessor.from_pretrained("SPAISS6F1/gemma-1b-pruned-th") model = AutoModelForMultimodalLM.from_pretrained("SPAISS6F1/gemma-1b-pruned-th") - Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use SPAISS6F1/gemma-1b-pruned-th with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "SPAISS6F1/gemma-1b-pruned-th" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "SPAISS6F1/gemma-1b-pruned-th", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker
docker model run hf.co/SPAISS6F1/gemma-1b-pruned-th
- SGLang
How to use SPAISS6F1/gemma-1b-pruned-th with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "SPAISS6F1/gemma-1b-pruned-th" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "SPAISS6F1/gemma-1b-pruned-th", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "SPAISS6F1/gemma-1b-pruned-th" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "SPAISS6F1/gemma-1b-pruned-th", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }' - Docker Model Runner
How to use SPAISS6F1/gemma-1b-pruned-th with Docker Model Runner:
docker model run hf.co/SPAISS6F1/gemma-1b-pruned-th
gemma-1b-pruned-th
โมเดลภาษาไทยขนาดเล็กที่ได้จากการ Depth Pruning (Layer Dropping) ของ unsloth/gemma-3-4b-pt (mirror ของ google/gemma-3-4b-pt)
แล้วทำ Healing SFT เพื่อกู้ความสามารถกลับมา
| Base model | unsloth/gemma-3-4b-pt (mirror ของ google/gemma-3-4b-pt) |
| Base size | 4.30B (34 layers) |
| โมเดลนี้ | 2.70B (เก็บ 17 layers) |
| Layers ที่เก็บ | [0-7, 25-33] (ตัด layer กลาง เก็บหัว+ท้าย) |
| Healing data | SEA-PILE v2 Thai (~8,000 docs) |
| Hardware | NVIDIA A100-40GB (Lanta HPC) |
| Requires | transformers>=4.50, accelerate |
Pipeline การสร้างโมเดล (ทำซ้ำได้)
ขั้นที่ 1 — Depth Pruning (Layer Dropping)
ตัด decoder layer ตรงกลางทิ้ง (มักทำงานซ้ำซ้อน) เก็บเฉพาะ layer หัว (เข้าใจ input) และ layer ท้าย (สร้าง output) — embedding / lm_head / norm คงเดิม จึงไม่พัง dimension
import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("unsloth/gemma-3-4b-pt (mirror ของ google/gemma-3-4b-pt)", torch_dtype=torch.bfloat16)
# หา text-decoder layer list (เลี่ยง vision encoder กรณี multimodal)
holder, layers = None, None
for _, mod in model.named_modules():
L = getattr(mod, "layers", None)
if isinstance(L, torch.nn.ModuleList) and len(L) and hasattr(L[0], "self_attn"):
holder, layers = mod, L
if "language" in _.lower() or "text" in _.lower():
break
N = len(layers)
# เก็บ 17 จาก 34 layers: หัว + ท้าย
keep = [0,1,2,3,4,5,6,7, 25,26,27,28,29,30,31,32,33]
holder.layers = torch.nn.ModuleList([layers[i] for i in keep])
# อัปเดต config (รองรับ nested text_config ของ Gemma3)
for c in {model.config, getattr(model.config, "text_config", model.config)}:
if getattr(c, "num_hidden_layers", None) is not None:
c.num_hidden_layers = len(keep)
lt = getattr(c, "layer_types", None)
if isinstance(lt, list) and len(lt) == N:
c.layer_types = [lt[i] for i in keep]
# reindex layer_idx ของแต่ละ block (สำคัญต่อ KV cache)
for i, lyr in enumerate(holder.layers):
if hasattr(lyr, "self_attn") and hasattr(lyr.self_attn, "layer_idx"):
lyr.self_attn.layer_idx = i
ผลลัพธ์: 4.30B -> 2.70B (ยังไม่ถึง 1B เป๊ะ เพราะ embedding+lm_head+vocab ไม่ลดตาม layer)
หลัง prune โมเดลจะพ่น gibberish ทันที (เส้นประสาทถูกตัดขาด) -> ต้อง Healing ต่อ
ขั้นที่ 2 — Healing SFT
เทรนต่อด้วย causal-LM บน Thai corpus เพื่อให้ layer ที่เหลือกลับมาทำงานร่วมกัน
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
ds = load_dataset("SPAISS6F1/spai-ss6-llm-1b-thai-corpus", split="train") # หรือ SEA-PILE v2 'th'
tds = ds.map(lambda e: tok(e["text"], truncation=True, max_length=1024),
batched=True, remove_columns=ds.column_names)
model.gradient_checkpointing_enable(); model.config.use_cache = False
args = TrainingArguments(output_dir="out", num_train_epochs=2,
per_device_train_batch_size=4, gradient_accumulation_steps=4,
learning_rate=5e-5, lr_scheduler_type="cosine", warmup_ratio=0.03, bf16=True)
Trainer(model=model, args=args, train_dataset=tds,
data_collator=DataCollatorForLanguageModeling(tok, mlm=False)).train()
Hyperparameters:
- Learning rate:
5e-5(สูงกว่าปกติเพื่อสมานแผล) | Epochs: 2 - Batch 4 x grad-accum 4 (effective 16) | max_len 1024 | bf16
- Optimizer: AdamW + cosine schedule, warmup 3%
- Env: venv overlay (transformers 4.53.3)
ขั้นที่ 3 — Save
model.config.use_cache = True
try:
model.save_pretrained("out", safe_serialization=True)
except RuntimeError: # Gemma3: tied embeddings -> fallback .bin
model.save_pretrained("out", safe_serialization=False)
โมเดลนี้ save เป็น: pytorch_model.bin (sharded; safetensors ไม่ได้เพราะ tied embeddings)
วิธีใช้ (Inference)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
m = "SPAISS6F1/gemma-1b-pruned-th"
tok = AutoTokenizer.from_pretrained(m)
model = AutoModelForCausalLM.from_pretrained(m, torch_dtype=torch.bfloat16, device_map="cuda")
ids = tok("ปัญญาประดิษฐ์ คือ", return_tensors="pt").to(model.device)
out = model.generate(**ids, max_new_tokens=120, do_sample=True,
temperature=0.7, top_p=0.9, repetition_penalty=1.3)
print(tok.decode(out[0], skip_special_tokens=True))
ข้อควรรู้ / ข้อจำกัด
- เป็น pruned base ที่ heal ด้วย raw web corpus -> ไวยากรณ์ไทยลื่นไหลดี แต่ ข้อเท็จจริงและการคิดเลขยังอ่อน (ยังไม่ผ่าน instruction tuning)
- แนะนำ
repetition_penalty >= 1.2กันการวนซ้ำ - เหมาะเป็น base สำหรับ fine-tune ต่อด้วย instruction dataset มากกว่าใช้ตอบตรง ๆ
- การตัด layer 50% เป็นการตัดที่ค่อนข้างหนัก (งานวิจัย เช่น ShortGPT แนะ ~25%); ถ้าต้องการคุณภาพสูงขึ้นควร heal นานขึ้น/ตัดเบาลง
- Downloads last month
- 101