JanPf committed
Commit 8ba72f5 · verified · 1 Parent(s): a2e8f54

Update README.md

Files changed (1):
  1. README.md: +56 -1
README.md CHANGED
@@ -18,4 +18,59 @@ license: other
  # LLäMmlein 1B Chat
  
  This is a chat adapter for the German Tinyllama 1B language model.
- Find more details on our [page](https://www.informatik.uni-wuerzburg.de/datascience/projects/nlp/llammlein/) and our [preprint](arxiv.org/abs/2411.11171)!
+ Find more details on our [page](https://www.informatik.uni-wuerzburg.de/datascience/projects/nlp/llammlein/) and our [preprint](arxiv.org/abs/2411.11171)!
+ 
+ ## Run it
+ ```py
+ import torch
+ from peft import PeftModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ 
+ torch.manual_seed(42)
+ 
+ # script config
+ base_model_name = "LSX-UniWue/llammchen_1b"
+ chat_adapter_name = "LSX-UniWue/LLaMmlein_1B_chat_sharegpt"
+ device = "mps"  # or "cuda"
+ 
+ # chat history ("Na wie geht's?" = "So, how's it going?")
+ messages = [
+     {
+         "role": "user",
+         "content": """Na wie geht's?""",
+     },
+ ]
+ 
+ # load the base model, then attach the chat adapter on top of it
+ base_model = AutoModelForCausalLM.from_pretrained(
+     base_model_name,
+     attn_implementation="flash_attention_2" if device == "cuda" else None,
+     torch_dtype=torch.bfloat16,
+     device_map=device,
+ )
+ # the adapter's tokenizer adds chat tokens, so the embedding matrix must be resized to match
+ base_model.resize_token_embeddings(32064)
+ model = PeftModel.from_pretrained(base_model, chat_adapter_name)
+ tokenizer = AutoTokenizer.from_pretrained(chat_adapter_name)
+ 
+ # encode the chat history in ChatML format and append the generation prompt
+ chat = tokenizer.apply_chat_template(
+     messages,
+     return_tensors="pt",
+     add_generation_prompt=True,
+ ).to(device)
+ 
+ # generate and decode the response
+ print(
+     tokenizer.decode(
+         model.generate(
+             chat,
+             max_new_tokens=300,
+             pad_token_id=tokenizer.pad_token_id,
+             eos_token_id=tokenizer.eos_token_id,
+         )[0],
+         skip_special_tokens=False,
+     )
+ )
+ 
+ ```
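
For completeness, a minimal sketch of merging the adapter into the base model for PEFT-free serving; this is not part of the commit. It continues from the `model` and `tokenizer` objects in the script above, and the output directory name is hypothetical:

```py
# Sketch: fold the LoRA adapter weights into the base model so the result
# can be reloaded later with plain transformers (no peft dependency).
merged_model = model.merge_and_unload()

# "llammlein_1b_chat_merged" is a hypothetical local path, not an official repo
merged_model.save_pretrained("llammlein_1b_chat_merged")
tokenizer.save_pretrained("llammlein_1b_chat_merged")
```

The saved model keeps the resized 32064-token embedding matrix, so it must be paired with the adapter's tokenizer saved alongside it.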