
Llama 3 8B Robot Instruction Model (4-bit)

Model Description

This model is a version of Llama 3 8B fine-tuned with Unsloth and quantized to 4-bit. It converts casual, natural-language instructions into function calls for controlling industrial robots. The goal is to let people without programming skills control robots through simple text instructions.

Model Details

  • Model ID: Studeni/llama-3-8b-bnb-4bit-robot-instruct
  • Architecture: Llama 3 8B
  • Quantization: 4-bit (bitsandbytes)
  • Framework: Transformers, PEFT, Unsloth
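For a rough sense of what 4-bit quantization saves on an 8B-parameter model, weight storage is parameters × bits ÷ 8 bytes. The sketch below is a weights-only lower bound that assumes a nominal 8 × 10⁹ parameters (not the exact Llama 3 count). It ignores layers that bitsandbytes typically leaves in 16-bit (embeddings and the output head), quantization constants, activations, and the KV cache, so actual GPU memory use will be higher. The helper name `weight_memory_gb` is illustrative only.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Assumption: nominal "8B" parameter count, weights only.
NOMINAL_PARAMS = 8e9

for bits in (16, 4):
    print(f"{bits}-bit weights: ~{weight_memory_gb(NOMINAL_PARAMS, bits):.1f} GB")
```

This prints roughly 16.0 GB for 16-bit weights and 4.0 GB for 4-bit weights, which is why the 4-bit model fits on smaller GPUs such as a 16 GB Tesla T4.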

Usage

Using the Unsloth Library

import json
from datasets import load_dataset
from unsloth import FastLanguageModel

# Dataset
repo_id = "Studeni/robot-instructions"
dataset = load_dataset(repo_id, split="test")
test_input = dataset[0]["input"]
test_output = dataset[0]["output"]
print(f"User input: {test_input}\nGround truth: {test_output}")

# Prompt
robot_instruct_prompt = """
### Instruction:
Transform input into list of function calls for controlling industrial robots.

### Input:
{}

### Response:
{}
"""

# Model Parameters
lora_id = "Studeni/llama-3-8b-bnb-4bit-robot-instruct"
max_seq_length = 2048
dtype = None  # None = auto-detect; otherwise float16 for Tesla T4/V100, bfloat16 for Ampere or newer
load_in_4bit = True

# Load the model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=lora_id,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
FastLanguageModel.for_inference(model)

# Tokenize input text
inputs = tokenizer(
    [robot_instruct_prompt.format(test_input, "")],
    return_tensors="pt",
).to("cuda")

# Run generation (a small max_new_tokens can cut the JSON list off mid-way)
outputs = model.generate(**inputs, max_new_tokens=256, use_cache=True)
text_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Extract the text after the response marker and parse it as JSON
function_call = text_output[0].split("### Response:")[-1].strip()
function_call = json.loads(function_call)
for f in function_call:
    print(f"Function to call: {f['function']}")
    print(f"Input parameters: {f['kwargs']}")
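The `dtype` comment in the Unsloth example maps GPU generations to precisions. A minimal sketch of that rule, assuming it is keyed on CUDA compute capability (Ampere and newer GPUs report a major version of 8 or higher; the V100 is 7.0 and the T4 is 7.5). The function `pick_dtype_name` is hypothetical and returns a name rather than a torch dtype so it runs without a GPU:

```python
def pick_dtype_name(capability: tuple[int, int]) -> str:
    """Map a CUDA compute capability (major, minor) to a dtype name.

    Ampere and newer GPUs (major >= 8) support bfloat16; older cards such as
    the V100 (7.0) and T4 (7.5) fall back to float16.
    """
    major, _minor = capability
    return "bfloat16" if major >= 8 else "float16"


for name, cap in [("V100", (7, 0)), ("T4", (7, 5)), ("A100", (8, 0))]:
    print(name, pick_dtype_name(cap))
```

On a GPU machine, `getattr(torch, pick_dtype_name(torch.cuda.get_device_capability()))` would turn the name into `torch.float16` or `torch.bfloat16` to pass as `dtype`; `torch.cuda.is_bf16_supported()` is an alternative check. Leaving `dtype=None`, as the example does, lets Unsloth make this choice itself.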

Using Transformers and PEFT

import json
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, BitsAndBytesConfig

# Dataset
repo_id = "Studeni/robot-instructions"
dataset = load_dataset(repo_id, split="test")
test_input = dataset[0]["input"]
test_output = dataset[0]["output"]
print(f"User input: {test_input}\nGround truth: {test_output}")

# Prompt
robot_instruct_prompt = """
### Instruction:
Transform input into list of function calls for controlling industrial robots.

### Input:
{}

### Response:
{}
"""

# Model Parameters
lora_id = "Studeni/llama-3-8b-bnb-4bit-robot-instruct"
load_in_4bit = True

# Load model and tokenizer (passing load_in_4bit directly is deprecated in recent
# Transformers releases; BitsAndBytesConfig is the supported way to request 4-bit)
model = AutoPeftModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=lora_id,
    quantization_config=BitsAndBytesConfig(load_in_4bit=load_in_4bit),
)
tokenizer = AutoTokenizer.from_pretrained(lora_id)

# Tokenize input text
inputs = tokenizer(
    [robot_instruct_prompt.format(test_input, "")],
    return_tensors="pt",
).to("cuda")

# Run generation
outputs = model.generate(**inputs, max_new_tokens=256, use_cache=True)
text_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Extract the text after the response marker and parse it as JSON
function_call = text_output[0].split("### Response:")[-1].strip()
function_call = json.loads(function_call)
for f in function_call:
    print(f"Function to call: {f['function']}")
    print(f"Input parameters: {f['kwargs']}")
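Both examples take everything after the last `### Response:` and pass it to `json.loads`, which fails if generation stops mid-list or if the model keeps writing after the closing bracket. The sketch below is one more tolerant option using only the standard library: it splits at the first marker (the one in the prompt) and uses `json.JSONDecoder.raw_decode`, which parses a single JSON value and reports where it ended, so trailing text is ignored. The name `extract_function_calls` is our own, not part of any library.

```python
import json
from typing import Any


def extract_function_calls(decoded: str, marker: str = "### Response:") -> list[Any]:
    """Parse the first JSON value that follows the response marker.

    Trailing text generated after the list is ignored. A truncated list
    (for example, when max_new_tokens is too small) still raises ValueError.
    """
    before, found, after = decoded.partition(marker)
    if not found:
        raise ValueError(f"marker {marker!r} not found in model output")
    response = after.strip()
    try:
        # raw_decode returns (value, end_index) and stops after one JSON value
        value, _end = json.JSONDecoder().raw_decode(response)
    except json.JSONDecodeError as exc:
        raise ValueError(f"could not parse function calls: {exc}") from exc
    if not isinstance(value, list):
        raise ValueError("expected a JSON list of function calls")
    return value
```

In either example, `function_call = extract_function_calls(text_output[0])` could replace the `split` and `json.loads` lines.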

Limitations and Future Work 🚨

This model is currently a work in progress and supports only three basic functions: move_tcp, move_joint, and get_joint_values. Future iterations will include a more comprehensive dataset with more complex commands and capabilities, better human-labeled data, and improved performance metrics.
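Because only these three functions are supported, output should be checked against an allowlist before anything reaches a robot. A minimal sketch, assuming hypothetical handler functions (they only echo their arguments; a real deployment would call the robot controller here) and the `function`/`kwargs` structure used in the usage examples:

```python
from typing import Any, Callable


# Hypothetical handlers standing in for a real robot controller API.
def move_tcp(**kwargs: Any) -> str:
    return f"move_tcp {kwargs}"


def move_joint(**kwargs: Any) -> str:
    return f"move_joint {kwargs}"


def get_joint_values(**kwargs: Any) -> str:
    return f"get_joint_values {kwargs}"


# Allowlist of the three functions the model currently supports.
HANDLERS: dict[str, Callable[..., str]] = {
    "move_tcp": move_tcp,
    "move_joint": move_joint,
    "get_joint_values": get_joint_values,
}


def dispatch(calls: list[Any]) -> list[str]:
    """Validate every call first, then execute them in order."""
    for i, call in enumerate(calls):
        if not isinstance(call, dict):
            raise ValueError(f"call {i}: expected an object, got {type(call).__name__}")
        name = call.get("function")
        if not isinstance(name, str) or name not in HANDLERS:
            raise ValueError(f"call {i}: unsupported function {name!r}")
        if not isinstance(call.get("kwargs"), dict):
            raise ValueError(f"call {i}: 'kwargs' must be a JSON object")
    return [HANDLERS[call["function"]](**call["kwargs"]) for call in calls]
```

With the usage code above, this would be called as `dispatch(function_call)` after parsing. Validating the whole list before running anything means a malformed entry late in the list cannot leave the robot with only part of an instruction executed.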

Contributions and Collaborations 🤝

We welcome contributions and collaborations to improve and extend this model. Whether you want to add more complex functions, improve the dataset, or enhance the model's performance, your input is valuable. Feel free to connect with me on LinkedIn.


This Llama model was trained 2x faster with Unsloth and Hugging Face's TRL library.
