|
--- |
|
license: apache-2.0 |
|
pipeline_tag: text-classification |
|
base_model: roberta-base
|
datasets: |
|
- DevQuasar/llm_router_dataset-synth |
|
language: |
|
- en |
|
--- |
|
|
|
The model's purpose is to determine whether a given user prompt — based on its complexity and domain — requires a SOTA (very large) LLM,

or can be de-escalated to a smaller or local model.
|
|
|
Example code: |
|
|
|
```python
|
from openai import OpenAI |
|
from datasets import load_dataset |
|
from datasets.dataset_dict import DatasetDict |
|
import json |
|
import random |
|
from transformers import ( |
|
RobertaTokenizerFast, |
|
RobertaForSequenceClassification, |
|
) |
|
from transformers import pipeline |
|
|
|
# Hugging Face model id of the fine-tuned RoBERTa prompt classifier.
model_id = 'DevQuasar/roberta-prompt_classifier-v0.1'

tokenizer = RobertaTokenizerFast.from_pretrained(model_id)

# Text-classification pipeline used to label prompts. "sentiment-analysis" is
# the transformers task alias for generic sequence classification; the labels
# produced here are the custom routing labels of the fine-tuned model, not
# sentiment classes.
sentence_classifier = pipeline(
    "sentiment-analysis", model=model_id, tokenizer=tokenizer
)
|
|
|
# Registry of available LLM backends, keyed by the classifier's output label
# (llm_router indexes this dict directly with the predicted label).
#   escalation_order: relative strength; routing only ever moves to a higher
#       order, never downgrades (see chat()).
#   url / api_key:    OpenAI-compatible endpoint (a local LM Studio server here).
#   model_id:         model identifier passed to the chat completions API.
#   max_ctx:          context window in tokens — declared but not enforced in
#       this example snippet.
model_store = {
    "small_llm": {
        "escalation_order": 0,
        "url": "http://localhost:1234/v1",
        "api_key": "lm-studio",
        "model_id": "lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf",
        "max_ctx": 4096
    },
    "large_llm": {
        "escalation_order": 1,
        "url": "http://localhost:1234/v1",
        "api_key": "lm-studio",
        "model_id": "lmstudio-community/Meta-Llama-3-70B-Instruct-GGUF/Meta-Llama-3-70B-Instruct-Q4_K_M.gguf",
        "max_ctx": 8192
    }
}
|
|
|
def prompt_classifier(user_prompt):
    """Return the classifier's predicted label for *user_prompt*.

    The label is later used as a key into ``model_store`` by ``llm_router``.
    """
    predictions = sentence_classifier(user_prompt)
    top_prediction = predictions[0]
    return top_prediction['label']
|
|
|
def llm_router(user_prompt, tokens_so_far=0):
    """Map *user_prompt* to its backend entry in ``model_store``.

    ``tokens_so_far`` is accepted for interface compatibility but is not used
    by the current routing logic.
    """
    label = prompt_classifier(user_prompt)
    return model_store[label]
|
|
|
def chat(user_prompt, model_store_entry=None, curr_ctx=None, system_prompt=' ', verbose=False):
    """Route *user_prompt* to an LLM, run one chat turn, and return the reply.

    On the first turn (no model selected yet) the prompt classifier picks the
    initial backend; on later turns the backend may only be escalated — never
    downgraded — when the classifier selects a higher ``escalation_order``.

    Args:
        user_prompt: The user's message for this turn.
        model_store_entry: ``model_store`` entry returned by a previous call,
            or None to let the router choose.
        curr_ctx: Running chat message history. Mutated in place and returned.
            None (or []) starts a fresh conversation.
        system_prompt: Currently unused; kept for interface compatibility.
        verbose: Print routing decisions and the raw completion when True.

    Returns:
        Tuple of (assistant_reply, updated_messages, model_store_entry).
    """
    # Fix: the original signature used `curr_ctx=[]` — a mutable default that
    # is shared across calls and silently accumulates history because the list
    # is appended to below. Normalize None to a fresh list instead.
    if curr_ctx is None:
        curr_ctx = []

    if model_store_entry is None:
        # Initial model selection (also covers a resumed context with no model,
        # which previously crashed on the escalation lookup below).
        model_store_entry = llm_router(user_prompt)
        if verbose:
            print(f'Classify prompt - selected model: {model_store_entry["model_id"]}')
    else:
        # Handle escalation: switch only to a strictly stronger backend.
        model_store_candidate = llm_router(user_prompt)
        if model_store_candidate["escalation_order"] > model_store_entry["escalation_order"]:
            model_store_entry = model_store_candidate
            if verbose:
                print(f'Escalate model - selected model: {model_store_entry["model_id"]}')

    url = model_store_entry['url']
    api_key = model_store_entry['api_key']
    model_id = model_store_entry['model_id']

    client = OpenAI(base_url=url, api_key=api_key)
    # `messages` deliberately aliases curr_ctx: the caller's history list is
    # updated in place and also returned.
    messages = curr_ctx
    messages.append({"role": "user", "content": user_prompt})

    try:
        completion = client.chat.completions.create(
            model=model_id,
            messages=messages,
            temperature=0.7,
        )
    finally:
        # Fix: release the HTTP connection even when the request raises.
        client.close()

    messages.append({"role": "assistant", "content": completion.choices[0].message.content})
    if verbose:
        print(f'Used model: {model_id}')
        print(f'completion: {completion}')
    return completion.choices[0].message.content, messages, model_store_entry
|
|
|
# Demo: two turns through the router, threading the model choice and the
# chat history between calls.
use_model = None  # model_store entry selected so far (None until first turn)
ctx = []          # running chat message history, updated in place by chat()

# start with simple prompt -> llama3-8b
res, ctx, use_model = chat(user_prompt="hello", model_store_entry=use_model, curr_ctx=ctx, verbose=True)

# escalate prompt -> llama3-70b
p = "Discuss the challenges and potential solutions for achieving sustainable development in the context of increasing global urbanization."
res, ctx, use_model = chat(user_prompt=p, model_store_entry=use_model, curr_ctx=ctx, verbose=True)
|
``` |