chtmp223 committed
Commit 18e5293 • 1 Parent(s): 4c1b998

Update README.md

Files changed (1)
  1. README.md +38 -2
README.md CHANGED
@@ -55,11 +55,47 @@ Use the code in [this repository](https://github.com/chtmp223/suri) for training
 | optim | adamw_torch |
 | per_device_train_batch_size | 1 |
 
-
- #### 🤗 Software
+ #### Software
 
 Training code is adapted from [Alignment Handbook](https://github.com/huggingface/alignment-handbook) and [Trl](https://github.com/huggingface/trl).
 
+ ## 🤗 Inference
+
+ ```py
+ import os
+
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel, PeftConfig
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "False"
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ torch.cuda.empty_cache()
+
+ # Load the PEFT adapter on top of the Mistral-7B-Instruct base model.
+ model_name = "chtmp223/suri-i-orpo"
+ base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
+ config = PeftConfig.from_pretrained(model_name)
+ base_model = AutoModelForCausalLM.from_pretrained(base_model_name).to(device)
+ model = PeftModel.from_pretrained(base_model, model_name).to(device)
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name)
+
+ # Build a chat-formatted prompt; `user_prompt` is your instruction string.
+ prompt = [
+     {
+         "role": "user",
+         "content": user_prompt,
+     }
+ ]
+ input_context = tokenizer.apply_chat_template(
+     prompt, add_generation_prompt=True, tokenize=False
+ )
+ input_ids = tokenizer.encode(
+     input_context, return_tensors="pt", add_special_tokens=False
+ ).to(model.device)
+ output = model.generate(
+     input_ids, max_length=10000, do_sample=True, use_cache=True
+ ).cpu()
+
+ print(tokenizer.decode(output[0]))
+ ```
+
+
 ## 📜 Citation
 
 ```
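
A usage note on the added inference snippet: `user_prompt` is left for the caller to define, and decoding `output[0]` prints the chat-formatted prompt back along with the completion. A minimal sketch of both pieces under those assumptions (the prompt text below is a placeholder, and the slice-before-decode step is a common refinement, not part of this commit):

```py
# Set before running the snippet: the instruction the model should follow.
user_prompt = "Write a short story about a lighthouse keeper."  # placeholder text

# ... run the README snippet above to obtain `tokenizer`, `input_ids`, `output` ...

# Optional refinement: decode only the newly generated tokens, so the
# echoed prompt is not printed back with the completion.
completion = tokenizer.decode(
    output[0][input_ids.shape[1]:], skip_special_tokens=True
)
print(completion)
```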