Update README.md
Browse files
README.md
CHANGED
@@ -27,38 +27,38 @@ base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
|
27 |
|
28 |
### Infrence function
|
29 |
|
30 |
-
def generate(review,category):
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
27 |
|
28 |
### Infrence function
|
29 |
|
30 |
+
def generate(review,category):
|
31 |
+
# Define the roles and markers
|
32 |
+
# Define the roles and markers
|
33 |
+
B_INST, E_INST = "[INST]", "[/INST]"
|
34 |
+
B_RW, E_RW = "[RW]", "[/RW]"
|
|
|
35 |
|
36 |
+
|
37 |
+
user_prompt = f'Summarize the reviews for {category} category.' ### custom prompt here
|
38 |
+
|
39 |
+
# Format your prompt template
|
40 |
+
# prompt = f"{B_FUNC}{functionList.strip()}{E_FUNC}{B_INST} {user_prompt.strip()} {E_INST} Hello! Life is good, thanks for asking {B_INST} {user_prompt2.strip()} {E_INST} The most fun dog is the Labrador Retriever {B_INST} {user_prompt3.strip()} {E_INST}\n\n"
|
41 |
+
prompt = f"{B_INST} {user_prompt.strip()} {E_INST}\n\n {B_RW} {review.strip()} {E_RW}\n"
|
42 |
+
|
43 |
+
print("Prompt:")
|
44 |
+
print(prompt)
|
45 |
+
|
46 |
+
encoding = tokenizer(prompt, return_tensors="pt").to("cuda:0")
|
47 |
+
output = model.generate(input_ids=encoding.input_ids,
|
48 |
+
attention_mask=encoding.attention_mask,
|
49 |
+
max_new_tokens=200,
|
50 |
+
do_sample=True,
|
51 |
+
temperature=0.01,
|
52 |
+
eos_token_id=tokenizer.eos_token_id,
|
53 |
+
top_k=0)
|
54 |
+
|
55 |
+
print()
|
56 |
+
|
57 |
+
# Subtract the length of input_ids from output to get only the model's response
|
58 |
+
output_text = tokenizer.decode(output[0, len(encoding.input_ids[0]):], skip_special_tokens=False)
|
59 |
+
output_text = re.sub('\n+', '\n', output_text) # remove excessive newline characters
|
60 |
+
|
61 |
+
print("Generated Assistant Response:")
|
62 |
+
print(output_text)
|
63 |
+
|
64 |
+
return output_text
|