Edit model card

The code below shows how this Buyer Persona generator can be used.

More documentation coming soon...

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm

# Use CUDA when torch reports an available GPU; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hub identifier of the checkpoint; the tokenizer and the weights are both
# loaded from it.
model_id = "danlou/persona-generator-llama-2-7b-qlora-merged"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Weights are loaded in float16; device_map="auto" delegates weight placement
# to the library.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
)


def chunks(lst, n):
    """Lazily split ``lst`` into consecutive slices of at most ``n`` items.

    ``lst`` must support ``len()`` and slicing. Every slice has ``n`` items
    except possibly the last one, which holds the remainder.
    """
    offsets = range(0, len(lst), n)
    for start in offsets:
        stop = start + n
        yield lst[start:stop]


def parse_outputs(output_text):
    """Parse one generated persona into a dictionary.

    The expected layout is exactly two newline-separated lines: a header of
    the form ``"<name>, <age>"`` followed by a description line longer than
    16 characters.

    Args:
        output_text: Generated text with the prompt already removed.

    Returns:
        A dict with the keys ``'name'`` (str), ``'age'`` (int) and
        ``'description'`` (str); name and description are whitespace-stripped.

    Raises:
        ValueError: If the text does not have the expected layout, or if the
            age field is not an integer.
    """
    # Validate with explicit checks rather than ``assert``: assertions are
    # removed when Python runs with -O, which would let malformed text through
    # to the unpacking below.
    output_lns = output_text.split('\n')
    if len(output_lns) != 2:
        raise ValueError('Malformed output.')

    header_fields = output_lns[0].split(',')
    if len(header_fields) != 2 or len(output_lns[1]) <= 16:
        raise ValueError('Malformed output.')

    name, age = (field.strip() for field in header_fields)
    desc = output_lns[1].strip()

    try:
        age = int(age)
    except ValueError:
        # Suppress the int() traceback; the message below is what callers see.
        raise ValueError('Malformed output (age).') from None

    return {'name': name, 'age': age, 'description': desc}



def generate_personas(product, n=1, batch_size=32, parse=True):
    """Sample buyer personas for a product description.

    Builds the instruction prompt for ``product`` and samples ``n``
    completions from the module-level ``model``, at most ``batch_size`` per
    ``generate`` call, showing a tqdm progress bar.

    Args:
        product: Product description inserted into the prompt.
        n: Number of completions to sample.
        batch_size: Maximum number of sequences requested per generate call.
        parse: When True, each completion is passed through ``parse_outputs``;
            when False, the raw completion text is returned.

    Returns:
        A list of parsed persona dicts (``parse=True``) or strings
        (``parse=False``). With ``parse=True`` the list can be shorter than
        ``n``: completions that fail to parse are reported with ``print`` and
        skipped.
    """
    prompt = f"### Instruction:\nDescribe the ideal persona for this product:\n{product}\n\n### Response:\n"

    # Keep the whole encoding (input ids and attention mask) and move it to
    # the device the model reports, so generate() receives an explicit mask.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    personas = []
    with tqdm(total=n) as pbar:
        for batch in chunks(range(n), batch_size):
            outputs = model.generate(**inputs,
                                     do_sample=True,
                                     num_beams=1,
                                     num_return_sequences=len(batch),
                                     max_length=512,
                                     min_length=32,
                                     temperature=0.9)

            for output_ids in outputs:
                # Each returned sequence starts with the prompt tokens. Drop
                # them by token count before decoding: slicing the decoded
                # string by len(prompt) is only right if decoding reproduces
                # the prompt character-for-character.
                completion_ids = output_ids[prompt_len:]
                output_decoded = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

                if not parse:
                    personas.append(output_decoded)
                    continue

                # Best effort: one malformed sample must not abort the batch.
                try:
                    personas.append(parse_outputs(output_decoded))
                except Exception as e:
                    print(e)

            pbar.update(len(batch))

    return personas


# Example: sample three personas for a single product listing and print them.
product = "Koonie 10000mAh Rechargeable Desk Fan, 8-Inch Battery Operated Clip on Fan, USB Fan, 4 Speeds, Strong Airflow, Sturdy Clamp for Golf Cart Office Desk Outdoor Travel Camping Tent Gym Treadmill, Black (USB Gadgets > USB Fans)"
personas = generate_personas(product, n=3)

for persona in personas:
    print(persona)

# Sample output (generation uses sampling, so results differ between runs):
#
# Persona 1 - The yoga instructor
# {'name': 'Sarah', 'age': 28, 'description': 'Yoga instructor who is passionate about health and fitness. She works from a home studio where she also practices yoga and meditation. Sarah values products that are eco-friendly and sustainable. She loves products that are versatile and can be used for different purposes. Sarah is looking for a product that is durable and can withstand frequent use. She values products that are stylish and aesthetically pleasing.'}
# Persona 2 - The golf enthusiast
# {'name': 'Sophia', 'age': 60, 'description': "Golf enthusiast. Sophia spends most of her weekends on the golf course, and she needs a fan that she can carry around in her golf cart. She needs a fan that's lightweight, easy to clip on, and has a long battery life. She also wants a fan that's affordable, especially since she plays at different courses."}
# Persona 3 - The truck driver
# {'name': 'Mike', 'age': 32, 'description': "Truck driver who spends most of his day on the road. The cab of his truck can get hot and stuffy, and Mike needs a fan that can keep him comfortable and alert while he's driving. He needs a fan that's easy to install and adjust, so he can keep it on his dashboard and direct the airflow where he needs it most."}
Downloads last month
3
Safetensors
Model size
6.74B params
Tensor type
FP16
·