---
license: mit
base_model:
- meta-llama/Meta-Llama-3-8B-Instruct
library_name: transformers
datasets:
- CreitinGameplays/reasoning-base-20k-llama3.1
---
# Meta Llama 3 8B Reasoning (Testing purposes only)

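Code example: an interactive loop that reads prompts from the terminal and streams the model's replies. On a CUDA machine the model is loaded in 8-bit, which requires the `bitsandbytes` and `accelerate` packages.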
```python
# test the model
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextStreamer,
)

def main():
    model_id = "CreitinGameplays/Llama-3.1-8b-reasoning-test"

    # Load the tokenizer. Don't ask it to append an EOS token: the prompt must
    # end with the open assistant turn so the model can continue it.
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # On CUDA, load the weights in 8-bit with bitsandbytes (needs the
    # bitsandbytes and accelerate packages); otherwise load on the CPU
    # with default settings.
    if torch.cuda.is_available():
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            device_map="auto",
        )
        device = torch.device("cuda")
    else:
        model = AutoModelForCausalLM.from_pretrained(model_id)
        device = torch.device("cpu")

    # Define the generation parameters.
    generation_kwargs = {
        "max_new_tokens": 2048,
        "do_sample": True,
        "temperature": 0.4,
        "top_k": 50,
        "top_p": 0.95,
        "repetition_penalty": 1.0,
        "num_return_sequences": 1,
        "forced_eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.eos_token_id
    }

    print("Enter your prompt (type 'exit' to quit):")
    while True:
        # Get user input.
        user_input = input("Input> ")
        if user_input.lower().strip() in ("exit", "quit"):
            break

        # Build the prompt in the model's chat format: Llama 3 header tokens,
        # with the assistant turn opened by the <|reasoning|> token.
        prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant named Llama, made by Meta AI. Always use your <|end_reasoning|> token when ending a reason step, without text formatting.<|eot_id|><|start_header_id|>user<|end_header_id|>

{user_input}<|eot_id|><|start_header_id|>assistant<|end_header_id|> <|reasoning|>
"""

        # Tokenize the prompt and move it to the selected device. The prompt
        # already starts with <|begin_of_text|>, so add_special_tokens=False
        # keeps the tokenizer from inserting a second BOS token.
        input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False).to(device)

        # Create a new TextStreamer instance for streaming responses.
        streamer = TextStreamer(tokenizer)
        generation_kwargs["streamer"] = streamer

        print("\nAssistant Response:")
        # Generate the text (tokens will stream to stdout via the streamer).
        outputs = model.generate(input_ids, **generation_kwargs)

if __name__ == "__main__":
    main()
```
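If you call the model from more than one place, it can help to keep the prompt template in a single function. The sketch below is a hypothetical helper (`build_prompt` and `DEFAULT_SYSTEM` are names introduced here, not part of this repository) that assembles the same system, user, and assistant segments as the example above into a plain string starting directly with `<|begin_of_text|>`.

```python
# Hypothetical helper: keeps the chat template used by this model card in one place.
DEFAULT_SYSTEM = (
    "You are a helpful assistant named Llama, made by Meta AI. "
    "Always use your <|end_reasoning|> token when ending a reason step, "
    "without text formatting."
)


def build_prompt(user_input: str, system: str = DEFAULT_SYSTEM) -> str:
    """Return a Llama 3 style chat prompt whose assistant turn opens with <|reasoning|>."""
    return (
        "<|begin_of_text|>"
        f"<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
        # Leave the assistant turn open so the model continues from <|reasoning|>.
        "<|start_header_id|>assistant<|end_header_id|> <|reasoning|>\n"
    )
```

To use it with the tokenizer loaded above, encode with `tokenizer.encode(build_prompt(text), return_tensors="pt", add_special_tokens=False)`: the string already contains `<|begin_of_text|>`, and letting the tokenizer add its own BOS token would duplicate it.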

Example model response (the streamer also echoes the prompt, so the output begins with the input text):
```
<|begin_of_text|>
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant named Llama, made by Meta AI. Always use your <|end_reasoning|> when ending a reason step.<|eot_id|><|start_header_id|>user<|end_header_id|>

how many r's are in strawberry?<|eot_id|><|start_header_id|>assistant<|end_header_id|><|reasoning|>
To determine the number of R's in the word "strawberry", we need to break down the word into its individual letters.

The word "strawberry" can be spelled out as S-T-R-A-W-B-E-R-R-Y. Now, let's examine each letter to count the number of R's.

1. S: There is no R in the letter S.
2. T: There is no R in the letter T.
3. R: Ah, we have our first R!
4. A: There is no R in the letter A.
5. W: There is no R in the letter W.
6. B: There is no R in the letter B.
7. E: There is no R in the letter E.
8. R: Another R! We now have two R's.
9. R: And another R! We now have three R's.
10. Y: There is no R in the letter Y.

So, after counting all the letters, we can see that the word "strawberry" contains three R's.

Therefore, the answer is 3.

<|end_reasoning|>
3. The word "strawberry" contains 3 R's.<|eot_id|>
```
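Because the output keeps its special tokens, the reasoning and the final answer can be separated with plain string operations. A minimal sketch, assuming the text was decoded with special tokens kept (for example `tokenizer.decode(outputs[0])`, whose default is `skip_special_tokens=False`) and that the answer follows the last `<|end_reasoning|>` marker as in the example above; `split_reasoning` is a hypothetical helper, not part of `transformers`.

```python
from typing import Tuple

REASONING_START = "<|reasoning|>"
REASONING_END = "<|end_reasoning|>"
EOT = "<|eot_id|>"


def split_reasoning(decoded: str) -> Tuple[str, str]:
    """Split decoded model output into (reasoning, answer).

    Assumes special tokens were kept when decoding. If the model never emitted
    <|end_reasoning|>, the whole text is returned as reasoning and the answer is "".
    """
    # Drop the echoed prompt: the first <|reasoning|> opens the assistant turn.
    # (The system message mentions <|end_reasoning|>, so cutting here matters.)
    if REASONING_START in decoded:
        decoded = decoded.split(REASONING_START, 1)[1]

    # Everything before the last end marker is reasoning; the rest is the answer.
    reasoning, sep, answer = decoded.rpartition(REASONING_END)
    if not sep:
        return decoded.split(EOT, 1)[0].strip(), ""

    # The answer ends at the end-of-turn token, if present.
    return reasoning.strip(), answer.split(EOT, 1)[0].strip()


if __name__ == "__main__":
    sample = (
        "<|start_header_id|>assistant<|end_header_id|><|reasoning|>\n"
        "To determine the number of R's in the word \"strawberry\", we need to "
        "break down the word into its individual letters.\n\n"
        "<|end_reasoning|>\n"
        "3. The word \"strawberry\" contains 3 R's.<|eot_id|>"
    )
    reasoning, answer = split_reasoning(sample)
    print("Reasoning:", reasoning)
    print("Answer:", answer)
```

Splitting on the last `<|end_reasoning|>` keeps any intermediate step markers inside the reasoning text, which suits prompts that ask for one marker per reasoning step.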