csabakecskemeti committed
Commit 02d846c
1 Parent(s): c95c32b

Update README.md

Files changed (1):
  1. README.md +90 -1
README.md CHANGED
---

The intention of the model is to determine whether a given user prompt's complexity and subject domain require a SOTA (very large) LLM, or whether the request can be de-escalated to a smaller or local model.
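For a quick sanity check of the classifier on its own, a minimal sketch (the label names are inferred from the `model_store` keys in the full example below; verify them against the model card):

```python
from transformers import pipeline

# Minimal check of the classifier by itself.
classifier = pipeline(
    "text-classification", model="DevQuasar/roberta-prompt_classifier-v0.1"
)

# Expected output shape: [{'label': 'small_llm' or 'large_llm', 'score': float}]
print(classifier("hello"))
print(classifier("Prove that the sum of two odd integers is even."))
```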

Example code:

```python
from openai import OpenAI
from transformers import RobertaTokenizerFast, pipeline

model_id = 'DevQuasar/roberta-prompt_classifier-v0.1'
tokenizer = RobertaTokenizerFast.from_pretrained(model_id)
# "sentiment-analysis" is an alias for the text-classification pipeline.
sentence_classifier = pipeline(
    "sentiment-analysis", model=model_id, tokenizer=tokenizer
)

# The classifier labels ("small_llm" / "large_llm") double as keys into this store.
model_store = {
    "small_llm": {
        "escalation_order": 0,
        "url": "http://localhost:1234/v1",
        "api_key": "lm-studio",
        "model_id": "lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf",
        "max_ctx": 4096
    },
    "large_llm": {
        "escalation_order": 1,
        "url": "http://localhost:1234/v1",
        "api_key": "lm-studio",
        "model_id": "lmstudio-community/Meta-Llama-3-70B-Instruct-GGUF/Meta-Llama-3-70B-Instruct-Q4_K_M.gguf",
        "max_ctx": 8192
    }
}

def prompt_classifier(user_prompt):
    return sentence_classifier(user_prompt)[0]['label']

def llm_router(user_prompt, tokens_so_far=0):
    return model_store[prompt_classifier(user_prompt)]

# Note: system_prompt and tokens_so_far are currently unused.
def chat(user_prompt, model_store_entry=None, curr_ctx=None, system_prompt=' ', verbose=False):
    if curr_ctx is None:
        curr_ctx = []
    if model_store_entry is None and curr_ctx == []:
        # First turn: pick the model by classifying the prompt.
        model_store_entry = llm_router(user_prompt)
        if verbose:
            print(f'Classify prompt - selected model: {model_store_entry["model_id"]}')
    else:
        # Handle escalation: only switch to a model with a higher escalation order.
        model_store_candidate = llm_router(user_prompt)
        if model_store_candidate["escalation_order"] > model_store_entry["escalation_order"]:
            model_store_entry = model_store_candidate
            if verbose:
                print(f'Escalate model - selected model: {model_store_entry["model_id"]}')
    url = model_store_entry['url']
    api_key = model_store_entry['api_key']
    model_id = model_store_entry['model_id']
    # max_ctx = model_store_entry['max_ctx']

    client = OpenAI(base_url=url, api_key=api_key)
    messages = curr_ctx
    messages.append({"role": "user", "content": user_prompt})

    completion = client.chat.completions.create(
        model=model_id,
        messages=messages,
        temperature=0.7,
    )
    messages.append({"role": "assistant", "content": completion.choices[0].message.content})
    if verbose:
        print(f'Used model: {model_id}')
        print(f'completion: {completion}')
    client.close()
    return completion.choices[0].message.content, messages, model_store_entry

use_model = None
ctx = []
# Start with a simple prompt -> routed to llama3-8b.
res, ctx, use_model = chat(user_prompt="hello", model_store_entry=use_model, curr_ctx=ctx, verbose=True)

# Escalate with a harder prompt -> routed to llama3-70b.
p = "Discuss the challenges and potential solutions for achieving sustainable development in the context of increasing global urbanization."
res, ctx, use_model = chat(user_prompt=p, model_store_entry=use_model, curr_ctx=ctx, verbose=True)
```
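Since `chat()` returns both the running context and the selected model entry, a multi-turn loop just feeds them back in. Escalation is one-way: once the conversation has moved to the larger model, `chat()` only switches on a higher `escalation_order`. A minimal, illustrative sketch (the REPL loop is not part of the model card):

```python
# Illustrative only: a simple REPL built on chat() from the example above.
use_model = None
ctx = []
while True:
    user_prompt = input("you> ")
    if user_prompt.strip().lower() in {"exit", "quit"}:
        break
    res, ctx, use_model = chat(user_prompt, model_store_entry=use_model, curr_ctx=ctx)
    print(res)
```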