huihui-ai committed
Commit cbd6171 · verified · 1 Parent(s): a4904a3

Update README.md

Files changed (1)
  1. README.md +142 -3
README.md CHANGED
@@ -1,3 +1,142 @@
- ---
- license: apache-2.0
- ---
---
base_model:
- huihui-ai/gemma-3-1b-it-abliterated
tags:
- text-generation-inference
- transformers
- unsloth
- gemma3_text
license: apache-2.0
language:
- en
datasets:
- huihui-ai/Guilherme34_uncensor
---

# huihui-ai/gemma-3-1b-it-abliterated-GRPO

- **Developed by:** huihui-ai
- **License:** apache-2.0
- **Finetuned from model:** [huihui-ai/gemma-3-1b-it-abliterated](https://huggingface.co/huihui-ai/gemma-3-1b-it-abliterated)

This gemma3_text model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Hugging Face's TRL library.

[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)

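The exact GRPO training setup (reward functions, hyperparameters, data preprocessing) is not published in this repository; only the base model, the dataset, and the Unsloth + TRL tooling are stated above. As a rough orientation, a minimal Unsloth + TRL GRPO run can be sketched as follows. This is only an illustrative sketch, not the authors' recipe: the LoRA settings, the `text` field name, and the placeholder reward function are assumptions.

```python
# Hypothetical GRPO fine-tuning sketch with Unsloth + TRL.
# NOT the published training script; hyperparameters, dataset field names,
# and the reward function below are illustrative assumptions only.
from unsloth import FastLanguageModel
from trl import GRPOConfig, GRPOTrainer
from datasets import load_dataset

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="huihui-ai/gemma-3-1b-it-abliterated",
    max_seq_length=2048,
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)

# GRPO expects a "prompt" column; the real column names in
# huihui-ai/Guilherme34_uncensor may differ (assumption).
dataset = load_dataset("huihui-ai/Guilherme34_uncensor", split="train")
dataset = dataset.map(lambda x: {"prompt": x["text"]})

def length_reward(completions, **kwargs):
    # Placeholder reward: GRPO needs at least one scoring function;
    # this one simply prefers completions near 400 characters.
    return [-abs(len(c) - 400) / 400 for c in completions]

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[length_reward],
    train_dataset=dataset,
    args=GRPOConfig(
        output_dir="gemma-3-1b-it-abliterated-grpo",
        per_device_train_batch_size=4,
        num_generations=4,          # completions sampled per prompt
        max_completion_length=256,
        max_steps=100,
        logging_steps=10,
    ),
)
trainer.train()
```

For a real run, the reward functions are the part that matters in GRPO; the length-based placeholder above only exists to make the sketch self-contained.
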
## Use with transformers

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer
import torch
import os
import signal

# Limit CPU threading to half of the available cores.
cpu_count = os.cpu_count()
print(f"Number of CPU cores in the system: {cpu_count}")
half_cpu_count = cpu_count // 2
os.environ["MKL_NUM_THREADS"] = str(half_cpu_count)
os.environ["OMP_NUM_THREADS"] = str(half_cpu_count)
torch.set_num_threads(half_cpu_count)

print(f"PyTorch threads: {torch.get_num_threads()}")
print(f"MKL threads: {os.getenv('MKL_NUM_THREADS')}")
print(f"OMP threads: {os.getenv('OMP_NUM_THREADS')}")

# Load the model and tokenizer
NEW_MODEL_ID = "huihui-ai/gemma-3-1b-it-abliterated-GRPO"
print(f"Load Model {NEW_MODEL_ID} ... ")

# Optional 4-bit quantization config; enable it by passing
# quantization_config=quant_config_4 to from_pretrained below.
quant_config_4 = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

model = AutoModelForCausalLM.from_pretrained(
    NEW_MODEL_ID,
    device_map="auto",
    trust_remote_code=True,
    #quantization_config=quant_config_4,
    torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(NEW_MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

initial_messages = [{"role": "system", "content": "You are a helpful assistant."}]
messages = initial_messages.copy()

# Streamer that records the generated text and can be interrupted via Ctrl+C.
class CustomTextStreamer(TextStreamer):
    def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=True):
        super().__init__(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
        self.generated_text = ""
        self.stop_flag = False

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.generated_text += text
        print(text, end="", flush=True)
        if self.stop_flag:
            raise StopIteration

    def stop_generation(self):
        self.stop_flag = True

def generate_stream(model, tokenizer, messages, max_new_tokens):
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    )
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    tokens = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)

    streamer = CustomTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Ctrl+C sets the stop flag instead of killing the process.
    def signal_handler(sig, frame):
        streamer.stop_generation()
        print("\n[Generation stopped by user with Ctrl+C]")

    signal.signal(signal.SIGINT, signal_handler)

    print("Response: ", end="", flush=True)
    try:
        generated_ids = model.generate(
            tokens,
            attention_mask=attention_mask,
            use_cache=False,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer
        )
        del generated_ids
    except StopIteration:
        print("\n[Stopped by user]")

    del input_ids, attention_mask
    torch.cuda.empty_cache()
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    return streamer.generated_text, streamer.stop_flag

# Simple interactive chat loop: /exit quits, /clear resets the history.
while True:
    user_input = input("User: ").strip()
    if user_input.lower() == "/exit":
        print("Exiting chat.")
        break
    if user_input.lower() == "/clear":
        messages = initial_messages.copy()
        print("Chat history cleared. Starting a new conversation.")
        continue
    if not user_input:
        print("Input cannot be empty. Please enter something.")
        continue
    messages.append({"role": "user", "content": user_input})
    response, stop_flag = generate_stream(model, tokenizer, messages, 8192)
    if stop_flag:
        continue
    messages.append({"role": "assistant", "content": response})
```
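
The script above loads full bfloat16 weights; the 4-bit `quant_config_4` it defines is left commented out. If GPU memory is tight, you can pass that config instead (requires the `bitsandbytes` package); roughly:

```python
# 4-bit loading using the quant_config_4 already defined in the script above.
model = AutoModelForCausalLM.from_pretrained(
    NEW_MODEL_ID,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=quant_config_4,
)
```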