Example script for using this model

#6
by JuLuComputing - opened

Hello all!

I've been testing several of these Switch models and decided to share the script I used. It keeps a simple chat history, but be aware that unless you know what you're doing with a Switch model, that history can really confuse things.
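For anyone wondering why the chat history trips these models up: as far as I know, the Switch Transformers checkpoints on the Hub were pretrained with a T5-style masked language modeling (span corruption) objective rather than for chat. They expect input where missing spans are marked with sentinel tokens such as <extra_id_0>, <extra_id_1>, and so on, and they generate text to fill those spans. Here is a small sketch of building that kind of prompt; mask_spans is a hypothetical helper of mine, not a transformers function.

```python
def mask_spans(words, spans):
    """Replace each (start, end) word span with a T5-style sentinel token.

    words: list of words. spans: sorted, non-overlapping (start, end)
    pairs with `end` exclusive. The i-th span becomes <extra_id_i>.
    """
    out, pos = [], 0
    for i, (start, end) in enumerate(spans):
        out.extend(words[pos:start])    # keep the words before this span
        out.append(f"<extra_id_{i}>")   # mark the removed span
        pos = end
    out.extend(words[pos:])             # keep whatever follows the last span
    return " ".join(out)


sentence = "A man walks into a bar and orders a drink".split()
print(mask_spans(sentence, [(1, 2), (9, 10)]))
# A <extra_id_0> walks into a bar and orders a <extra_id_1>
```

The resulting string can be tokenized and passed to model.generate the same way the script passes conversation_text. When decoding an answer to this kind of prompt, keep the special tokens (don't pass skip_special_tokens=True): the tokenizer treats the sentinels as special tokens, so skipping them drops the markers that separate the filled-in spans.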

Anyway, I tested on CPU only. While running, this model used an additional 60GB of RAM, so I'd make sure you run it on a computer with at least 96GB of RAM. Here is the script; if you have any questions, feel free to ask me:

import time
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration

print("Loading model, please wait...")

# Full path to the local directory holding the downloaded model files (edit this)
MODEL_PATH = "/whole/path/to/google_switch-base-256/"

# Timer for model loading
loading_start_time = time.time()

tokenizer = AutoTokenizer.from_pretrained("google/switch-base-256")
model = SwitchTransformersForConditionalGeneration.from_pretrained(MODEL_PATH, local_files_only=True)

loading_end_time = time.time()
loading_time = loading_end_time - loading_start_time

print(f"Model loading time: {loading_time:.2f} seconds")

# Simple conversation memory
conversation_history = []

while True:
    # Get user input
    user_input = input("User: ")

    # Add user input to conversation history
    conversation_history.append(f"User: {user_input}")

    # Combine conversation history into a single string
    conversation_text = " ".join(conversation_history)
    
    # Timer for generating the response
    generation_start_time = time.time()

    # Tokenize and generate response
    input_ids = tokenizer(conversation_text, return_tensors="pt").input_ids
    decoder_start_token_id = tokenizer.convert_tokens_to_ids("<pad>")
    max_new_tokens = 2048  # Adjust the value as needed
    outputs = model.generate(input_ids, decoder_start_token_id=decoder_start_token_id, max_new_tokens=max_new_tokens)

    generation_end_time = time.time()
    generation_time = generation_end_time - generation_start_time

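    # Decode the output, dropping special tokens such as <pad> and </s> so they
    # are not printed or fed back into the conversation history
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"SwitchBot: {generated_text}")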

    print(f"Response generation time: {generation_time:.2f} seconds")

    # Token counter (the first output token is the decoder start token, not generated text)
    response_tokens = len(outputs[0]) - 1
    tokens_per_second = response_tokens / generation_time

    print(f"Number of response tokens: {response_tokens}")
    print(f"Tokens per second: {tokens_per_second:.2f}")

    # Add SwitchBot's response to conversation history
    conversation_history.append(f"SwitchBot: {generated_text}")
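The loop above joins the entire history into every prompt, so the input grows with each turn. One way to limit how much old context the model sees is to keep only the last few entries. Below is a minimal sketch; build_prompt and max_turns are hypothetical names (not part of the script or of transformers). In the loop it would stand in for the append and join steps, with the bot's reply still appended afterwards. Whether a shorter window actually helps with a Switch model is something to test.

```python
def build_prompt(history, user_input, max_turns=4):
    """Append the new user turn; return (updated_history, prompt_text).

    Only the last `max_turns` entries are joined into the prompt, so
    older turns stop reaching the model once the conversation gets long.
    """
    if max_turns < 1:
        # history[-0:] would silently return the whole list, so reject it
        raise ValueError("max_turns must be at least 1")
    history = history + [f"User: {user_input}"]
    prompt = " ".join(history[-max_turns:])
    return history, prompt


history = []
history, prompt = build_prompt(history, "Hello")
history.append("SwitchBot: Hi")  # the bot reply is appended as in the script
history, prompt = build_prompt(history, "How are you?", max_turns=2)
print(prompt)  # SwitchBot: Hi User: How are you?
```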

Note: I downloaded all the model files to a local directory and ran the script from that same directory. You will need to edit the model path in this script so it points to the full path of your own directory.
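If you'd rather not edit the script each time, the path can be picked up at run time instead. This is a sketch under my own assumptions: the first command-line argument wins, then a hypothetical SWITCH_MODEL_PATH environment variable, and otherwise the current directory, which matches running the script from the folder that holds the model files.

```python
import os
import sys


def resolve_model_path(argv=None, env=None, default="."):
    """Pick the model directory: first CLI argument, then the
    SWITCH_MODEL_PATH environment variable (a made-up name), then `default`."""
    argv = sys.argv[1:] if argv is None else argv
    env = os.environ if env is None else env
    if argv:
        path = argv[0]
    else:
        path = env.get("SWITCH_MODEL_PATH", default)
    return os.path.abspath(path)  # from_pretrained gets an absolute path


# In the chat script, instead of the hard-coded path:
# model = SwitchTransformersForConditionalGeneration.from_pretrained(
#     resolve_model_path(), local_files_only=True)

if __name__ == "__main__":
    print(resolve_model_path())
```

With this in place you could run something like python chat.py /whole/path/to/google_switch-base-256/ (the script name is just an example), or run it with no argument from inside the model directory.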

Thanks for sharing, @JuLuComputing!

Google org

Thanks, @JuLuComputing!

You bet, peeps! 😁 I'm glad to help the community. Let me know if you have any questions about the script.
