VietnamAIHub committed
Commit b48c636
1 Parent(s): 9544ad6

update_model

Files changed (1): README.md +112 -3

README.md CHANGED
---
license: apache-2.0
---
# Vietnamese Llama2-7B 8K Context Length with LoRA Adapters

This repository contains a Llama2-7B model fine-tuned with QLoRA (Quantized Low-Rank Adaptation) adapters. The adapters are plug-and-play components that enable the Llama model to perform well on many Vietnamese NLP tasks.

Project GitHub page: [Github](https://github.com/VietnamAIHub/Vietnamese_LLMs)

## Model Overview

The Vietnamese Llama2-7B model is a large language model that generates fluent Vietnamese text and can be used for a wide variety of natural language processing tasks, including text generation, sentiment analysis, and more. By using LoRA adapters, the model achieves better performance on low-resource tasks and demonstrates improved generalization.
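
If you want to attach adapter weights to a base Llama-2 checkpoint yourself rather than downloading merged weights, a minimal sketch with the `peft` library looks like the following; the adapter repository id here is hypothetical and only illustrates the pattern:

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the frozen base model, then attach the trained LoRA adapter on top of it
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
)
model = PeftModel.from_pretrained(base, "VietnamAIHub/vietnamese-lora-adapter")  # hypothetical repo id

# Optionally fold the adapter back into the base weights for faster inference
model = model.merge_and_unload()
```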

## Dataset and Fine-Tuning

The Llama2 model was fine-tuned on over 200K Vietnamese instructions collected from various sources to improve its ability to understand and generate text for different tasks.

Dataset link: Coming soon
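
For readers unfamiliar with QLoRA, the technique quantizes the frozen base model to 4-bit and trains low-rank adapters on top of it. A minimal configuration sketch with `transformers` and `peft` is shown below; the hyperparameters are illustrative assumptions, not the values used to train this model:

```python
import torch
from transformers import BitsAndBytesConfig
from peft import LoraConfig

# 4-bit NF4 quantization of the frozen base model, as described in the QLoRA paper
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Low-rank adapters trained on the attention projections
# (illustrative hyperparameters, not the authors' training recipe)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
```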

## Testing the Model Yourself

To load the fine-tuned Llama2-7B model and generate text with it, follow the code snippet below:

```python
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "VietnamAIHub/Vietnamese_llama2_7B_8K_SFT_General_domain"

## Load the model weights (the adapter is already merged into this checkpoint)
m = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,           # 8-bit quantization to reduce GPU memory usage
    torch_dtype=torch.bfloat16,
    pretraining_tp=1,
    # cache_dir="path/to/cache", # optionally set a local cache directory
)

tok = AutoTokenizer.from_pretrained(
    model_name,
    padding_side="right",
    use_fast=False,              # the fast tokenizer gives issues with this checkpoint
    tokenizer_type="llama",      # needed for the HF name change
)
tok.bos_token_id = 1
if tok.pad_token_id is None:
    tok.pad_token_id = tok.eos_token_id  # Llama has no pad token by default

# Stop generation as soon as the model emits one of these token ids
stop_token_ids = [0]

class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return input_ids[0][-1].item() in stop_token_ids

# Build a Llama-2 chat prompt: system instructions in <<SYS>> tags, then the user request
prompt_input = "Cách để học tập về một môn học thật tốt"  # "How to study a subject really well"
message = (
    "<s>[INST] <<SYS>>\n"
    "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, "
    "while being safe. Your answers should not include any harmful, unethical, racist, sexist, "
    "toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased "
    "and positive in nature.\n"
    "If a question does not make any sense, or is not factually coherent, explain why instead of "
    "answering something incorrect. If you don't know the answer to a question, respond that as a "
    "language model you are not able to answer this kind of question in detail.\n"
    f"<</SYS>>\n\n{prompt_input} [/INST] "
)

generation_config = dict(
    temperature=0.1,
    top_k=30,
    top_p=0.95,
    do_sample=True,
    repetition_penalty=1.2,
    max_new_tokens=2048,
    stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
)

inputs = tok(message, return_tensors="pt")
generation_output = m.generate(
    input_ids=inputs["input_ids"].to(device),
    attention_mask=inputs["attention_mask"].to(device),
    eos_token_id=tok.eos_token_id,
    pad_token_id=tok.pad_token_id,
    **generation_config,
)

output = tok.decode(generation_output[0], skip_special_tokens=True)
# Keep only the model's answer by splitting on the closing instruction tag
response = output.split("[/INST]")[-1].strip()
print(response)
```
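
The original setup also creates a token streamer; if you want tokens printed as they are generated instead of waiting for the full output, a minimal sketch reusing `m`, `tok`, `inputs`, and `generation_config` from the snippet above looks like this:

```python
from threading import Thread
from transformers import TextIteratorStreamer

# The streamer yields decoded text chunks as generate() produces them
streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

thread = Thread(
    target=m.generate,
    kwargs=dict(
        input_ids=inputs["input_ids"].to(device),
        attention_mask=inputs["attention_mask"].to(device),
        streamer=streamer,
        **generation_config,
    ),
)
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()
```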

## Conclusion

The Vietnamese Llama2-7B model with LoRA adapters is a versatile language model that can be used for a wide range of Vietnamese NLP tasks. We hope that researchers and developers find this model useful and are encouraged to experiment with it in their projects.

For any questions, feedback, or contributions, please feel free to contact the maintainer of this repository, TranNhiem 🙌: [Linkedin](https://www.linkedin.com/in/tran-nhiem-ab1851125/), [Twitter](https://twitter.com/TranRick2), [Facebook](https://www.facebook.com/jean.tran.336), or the project [Discord](https://discord.gg/MC3yDZNz). Happy fine-tuning and experimenting with the Llama2-7B model!