import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr
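
# Base model checkpoint on the Hugging Face Hub and the LoRA adapter fine-tuned from it.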
base_model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
adapter_path = "BeDream/tuning-lora-tinyllama-1.1b"
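
# Load the tokenizer from the adapter repo so any tokens added during fine-tuning are included.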
tokenizer = AutoTokenizer.from_pretrained(adapter_path)
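
# Guarantee an EOS token; fall back to a generic "<|endoftext|>" if the tokenizer ships without one.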
if tokenizer.eos_token is None:
    tokenizer.eos_token = "<|endoftext|>"
    tokenizer.add_special_tokens({'eos_token': tokenizer.eos_token})
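
# A pad token that is missing or aliased to EOS makes generate() warn, and a
# collator that masks pad labels would also mask EOS during fine-tuning, so a
# dedicated [PAD] token is registered instead.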
if tokenizer.pad_token is None or tokenizer.pad_token == tokenizer.eos_token:
    print("⚠️ pad_token is missing or equal to eos_token. Adding a new pad_token...")
    tokenizer.pad_token = '[PAD]'
    tokenizer.add_special_tokens({'pad_token': tokenizer.pad_token})
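
# Zephyr-style chat template used by TinyLlama-Chat: each turn is tagged
# <|system|>/<|user|>/<|assistant|> and closed with the EOS token; the string
# opens right at {% for %} so no stray newline leads the rendered prompt.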
tokenizer.chat_template = """{% for message in messages %}
{% if message['role'] == 'user' %}
{{ '<|user|>\n' + message['content'] + eos_token }}
{% elif message['role'] == 'system' %}
{{ '<|system|>\n' + message['content'] + eos_token }}
{% elif message['role'] == 'assistant' %}
{{ '<|assistant|>\n' + message['content'] + eos_token }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ '<|assistant|>' }}
{% endif %}
{% endfor %}"""
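
# Load the base model in half precision, sharding across available devices.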
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype=torch.float16,
    device_map="auto"
)
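# Resize the embedding matrix before attaching the adapter so rows for the
# newly added tokens exist when the adapter weights are applied.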
model.resize_token_embeddings(len(tokenizer))
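
# Attach the LoRA adapter on top of the base weights.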
model = PeftModel.from_pretrained(model, adapter_path)
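
# Gradio callback: format the message with the chat template, sample a
# completion, and return only the assistant's reply.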
def chat_fn(message, history):
    # Single-turn only: `history` is not used, so the model does not see
    # earlier turns of the conversation.
    messages = [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Nucleus sampling, capped at 200 new tokens.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode the whole sequence, then keep only the text after the last
    # <|assistant|> tag; the role tags are ordinary text here, so they survive
    # skip_special_tokens=True.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if "<|assistant|>" in response:
        response = response.split("<|assistant|>")[-1].strip()
    return response
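
# Minimal Gradio chat UI around chat_fn; launch() starts a local server.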
iface = gr.ChatInterface(fn=chat_fn)
iface.launch()