Vishal24's picture
Update README.md
798c004
|
raw
history blame
2.31 kB
metadata
library_name: peft
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0

Model Card for Model ID

Model Details

Model Description

  • Developed by: [More Information Needed]
  • Funded by [optional]: [More Information Needed]
  • Shared by [optional]: [More Information Needed]
  • Model type: [More Information Needed]
  • Language(s) (NLP): [More Information Needed]
  • License: [More Information Needed]
  • Finetuned from model [optional]: [More Information Needed]

Inference function

def generate(review: str, category: str) -> str:
    """Generate a category-specific summary of a review with the fine-tuned model.

    The review is wrapped in ``[RW]``/``[/RW]`` markers and preceded by an
    ``[INST]``/``[/INST]`` instruction asking for a summary of the given
    category. The prompt and the generated text are also printed to stdout.

    This function relies on two names that must already exist in the
    enclosing module: ``tokenizer`` and ``model`` (the loaded tokenizer and
    the model to generate with).

    Args:
        review: Raw review text; surrounding whitespace is stripped.
        category: Category name interpolated into the instruction.

    Returns:
        The decoded continuation only (prompt tokens excluded), with runs of
        consecutive newlines collapsed to one. Special tokens are kept in the
        text because decoding uses ``skip_special_tokens=False``.
    """
    # Local import so the snippet is self-contained when copied from the card.
    import re

    # Role markers used by the prompt template.
    B_INST, E_INST = "[INST]", "[/INST]"
    B_RW, E_RW = "[RW]", "[/RW]"

    user_prompt = f'Summarize the reviews for {category} category.'  # custom prompt here

    # Prompt template: instruction block, blank line, then the review block.
    prompt = f"{B_INST} {user_prompt.strip()} {E_INST}\n\n {B_RW} {review.strip()} {E_RW}\n"

    print("Prompt:")
    print(prompt)

    # Place the inputs on whatever device the model lives on instead of
    # assuming the first CUDA device is available.
    encoding = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(
        input_ids=encoding.input_ids,
        attention_mask=encoding.attention_mask,
        max_new_tokens=200,
        # Sampling with a very low temperature and top-k filtering disabled.
        do_sample=True,
        temperature=0.01,
        eos_token_id=tokenizer.eos_token_id,
        top_k=0,
    )

    print()

    # The returned sequence starts with the prompt tokens; skip them so only
    # the newly generated tokens are decoded.
    prompt_length = encoding.input_ids.shape[1]
    output_text = tokenizer.decode(output[0, prompt_length:], skip_special_tokens=False)
    output_text = re.sub(r"\n+", "\n", output_text)  # collapse repeated newlines

    print("Generated Assistant Response:")
    print(output_text)

    return output_text