Browse files

- .gitignore +1 -0
- __pycache__/constant.cpython-311.pyc +0 -0
- __pycache__/utils.cpython-311.pyc +0 -0
- app.py +139 -51
- app_single.py +117 -0
- constant.py +54 -0
- list_models.py +24 -0
- together_model_ids.json +179 -0
- utils.py +64 -30
    	
.gitignore ADDED

@@ -0,0 +1 @@
+__pycache__/
    	
__pycache__/constant.cpython-311.pyc CHANGED

Binary files a/__pycache__/constant.cpython-311.pyc and b/__pycache__/constant.cpython-311.pyc differ
    	
__pycache__/utils.cpython-311.pyc CHANGED

Binary files a/__pycache__/utils.cpython-311.pyc and b/__pycache__/utils.cpython-311.pyc differ
    	
app.py CHANGED

@@ -3,8 +3,8 @@ import os
 from typing import List
 import logging
 import urllib.request
-from utils import model_name_mapping, urial_template, openai_base_request
-from constant import js_code_label, HEADER_MD
+from utils import model_name_mapping, urial_template, openai_base_request, chat_template, openai_chat_request
+from constant import js_code_label, HEADER_MD, BASE_TO_ALIGNED, MODELS
 from openai import OpenAI
 import datetime
 # add logging info to console
@@ -19,28 +19,49 @@ STOP_STRS = ['"""', '# Query:', '# Answer:']
 addr_limit_counter = {}
 LAST_UPDATE_TIME = datetime.datetime.now()
 
+models = MODELS
+
+# mega_hist = {
+#     "base": [],
+#     "aligned": []
+# }
+
 def respond(
     message,
     history: list[tuple[str, str]],
     max_tokens,
     temperature,
     top_p,
     rp,
     model_name,
-    together_api_key,
+    model_type,
+    api_key,
     request:gr.Request
 ):
     global STOP_STRS, urial_prompt, LAST_UPDATE_TIME, addr_limit_counter
-    rp = 1.0
-    prompt = urial_template(urial_prompt, history, message)
+
+    assert model_type in ["base", "aligned"]
+    # if history:
+    #     if model_type == "base":
+    #         mega_hist["base"] = history
+    #     else:
+    #         mega_hist["aligned"] = history
+
+    if model_type == "base":
+        prompt = urial_template(urial_prompt, history, message)
+    else:
+        messages = chat_template(history, message)
 
     # _model_name = "meta-llama/Llama-3-8b-hf"
     _model_name = model_name_mapping(model_name)
 
-    if together_api_key and len(together_api_key) == 64:
-        api_key = together_api_key
+    if api_key and len(api_key) == 64:
+        api_key = api_key
     else:
         api_key = None
 
     # headers = request.headers
     # if already 24 hours passed, reset the counter
@@ -53,12 +74,21 @@ def respond(
     if addr_limit_counter[host_addr] > 100:
         return "You have reached the limit of 100 requests for today. Please use your own API key."
 
-    infer_request = openai_base_request(prompt=prompt, model=_model_name,
-                                   temperature=temperature,
-                                   max_tokens=max_tokens,
-                                   top_p=top_p,
-                                   repetition_penalty=rp,
-                                   stop=STOP_STRS, api_key=api_key)
+    if model_type == "base":
+        infer_request = openai_base_request(prompt=prompt, model=_model_name,
+                                    temperature=temperature,
+                                    max_tokens=max_tokens,
+                                    top_p=top_p,
+                                    repetition_penalty=rp,
+                                    stop=STOP_STRS, api_key=api_key)
+    else:
+        infer_request = openai_chat_request(messages=messages, model=_model_name,
+                                    temperature=temperature,
+                                    max_tokens=max_tokens,
+                                    top_p=top_p,
+                                    repetition_penalty=rp,
+                                    stop=STOP_STRS, api_key=api_key)
+
     addr_limit_counter[host_addr] += 1
     logging.info(f"Requesting chat completion from OpenAI API with model {_model_name}")
     logging.info(f"addr_limit_counter: {addr_limit_counter}; Last update time: {LAST_UPDATE_TIME};")
@@ -66,45 +96,103 @@ def respond(
     response = ""
     for msg in infer_request:
         # print(msg.choices[0].delta.keys())
-        if hasattr(msg.choices[0], "delta"):
-            token = msg.choices[0].delta["content"]
-        else:
-            token = msg.choices[0].text
-        should_stop = False
-        for _stop in STOP_STRS:
-            if _stop in response + token:
-                should_stop = True
-                break
-        if should_stop:
-            break
-        response += token
-        if response.endswith('\n"'):
-            response = response[:-1]
-        elif response.endswith('\n""'):
-            response = response[:-2]
-        yield response
+        if hasattr(msg.choices[0], "delta"):
+            # Note: a 'ChoiceDelta' object may or may not be subscriptable
+            if "content" in msg.choices[0].delta:
+                token = msg.choices[0].delta["content"]
+            else:
+                token = msg.choices[0].delta.content
+        else:
+            token = msg.choices[0].text
+        if model_type == "base":
+            should_stop = False
+            for _stop in STOP_STRS:
+                if _stop in response + token:
+                    should_stop = True
+                    break
+            if should_stop:
+                break
+        if token is None:
+            continue
+        response += token
+        if model_type == "base":
+            if response.endswith('\n"'):
+                response = response[:-1]
+            elif response.endswith('\n""'):
+                response = response[:-2]
+        yield history + [(message, response)]
+    # mega_hist[model_type].append((message, response))
+    # yield mega_hist[model_type]
 
-with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
+
+def load_models(base_model_name):
+    print(f"base_model_name={base_model_name}")
+    out_box = [gr.Chatbot(), gr.Chatbot(), gr.Dropdown()]
+    out_box[0] = (gr.update(label=f"Chat with Base LLM: {base_model_name}"))
+    aligned_model_name = BASE_TO_ALIGNED[base_model_name]
+    out_box[1] = (gr.update(label=f"Chat with Aligned LLM: {aligned_model_name}"))
+    out_box[2] = (gr.update(value=aligned_model_name, interactive=False))
+    return out_box[0], out_box[1], out_box[2]
+
+def clear_fn():
+    # mega_hist["base"] = []
+    # mega_hist["aligned"] = []
+    return None, None, None
+
+with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
+    api_key = gr.Textbox(label="🔑 APIKey", placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.", type="password", elem_id="api_key", visible=False)
+
+    gr.Markdown(HEADER_MD)
+
     with gr.Row():
-        with gr.Column():
-            gr.Markdown(HEADER_MD)
-            model_name = gr.Radio(["Llama-3-8B", "Llama-3-70B", "Mistral-7B-v0.1", 
-                                   "Mixtral-8x22B", "Qwen1.5-72B", "Yi-34B", "Llama-2-7B", "Llama-2-70B", "OLMO"]
-                                  , value="Llama-3-8B", label="Base LLM name")
-        with gr.Column():
-            together_api_key = gr.Textbox(label="🔑 Together APIKey", placeholder="Enter your Together API Key. Leave it blank to use our key with limited usage.", type="password", elem_id="api_key")
-            with gr.Column():
-                with gr.Row():
-                    max_tokens = gr.Textbox(value=256, label="Max tokens")
-                    temperature = gr.Textbox(value=0.5, label="Temperature")
-                    top_p = gr.Textbox(value=0.9, label="Top-p")
-                    rp = gr.Textbox(value=1.1, label="Repetition penalty")
-    chat = gr.ChatInterface(
-        respond,
-        additional_inputs=[max_tokens, temperature, top_p, rp, model_name, together_api_key],
-        # additional_inputs_accordion="⚙️ Parameters",
-        # fill_height=True, 
-    )
-    chat.chatbot.label="Chat with Base LLMs via URIAL"
-    chat.chatbot.height = 550
-    chat.chatbot.show_copy_button = True 
+        chat_a = gr.Chatbot(height=500, label="Chat with Base LLMs via URIAL")
+        chat_b = gr.Chatbot(height=500, label="Chat with Aligned LLMs")
 
+    with gr.Group():
+        with gr.Row():
+            with gr.Column(scale=2):
+                message = gr.Textbox(label="Prompt", placeholder="Enter your message here")
+                with gr.Row():
+                    with gr.Column(scale=2):
+                        with gr.Row():
+                            left_model_choice = gr.Dropdown(label="Base Model", choices=models, interactive=True)
+                            right_model_choice = gr.Textbox(label="Aligned Model", placeholder="xxx", visible=True)
+                        with gr.Row():
+                            btn = gr.Button("🚀 Chat")
+                        # gr.Markdown("---")
+                        with gr.Row():
+                            stop_btn = gr.Button("⏸️ Stop")
+                            clear_btn = gr.Button("🔁 Clear")
+                        with gr.Row():
+                            gr.Markdown("We thank [Hyperbolic AI](https://hyperbolic.xyz/) for their support.")
+            with gr.Column(scale=1):
+                with gr.Accordion("⚙️ Params for **Base** LLM", open=True):
+                    with gr.Row():
+                        max_tokens_1 = gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=2048, step=16, interactive=True, visible=True)
+                        temperature_1 = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
+                    with gr.Row():
+                        top_p_1 = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
+                        rp_1 = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.1)
+                with gr.Accordion("⚙️ Params for **Aligned** LLM", open=True):
+                    with gr.Row():
+                        max_tokens_2 = gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=2048, step=16, interactive=True, visible=True)
+                        temperature_2 = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
+                    with gr.Row():
+                        top_p_2 = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
+                        rp_2 = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)
+
+    left_model_choice.change(load_models, [left_model_choice], [chat_a, chat_b, right_model_choice])
+
+    model_type_left = gr.Textbox(visible=False, value="base")
+    model_type_right = gr.Textbox(visible=False, value="aligned")
+
+    go1 = btn.click(respond, [message, chat_a, max_tokens_1, temperature_1, top_p_1, rp_1, left_model_choice, model_type_left, api_key], chat_a)
+    go2 = btn.click(respond, [message, chat_b, max_tokens_2, temperature_2, top_p_2, rp_2, right_model_choice, model_type_right, api_key], chat_b)
+
+    stop_btn.click(None, None, None, cancels=[go1, go2])
+    clear_btn.click(clear_fn, None, [message, chat_a, chat_b])
 
 if __name__ == "__main__": 
     demo.launch(show_api=False)
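Note on the rewrite above: the side-by-side streaming works because `btn.click` is registered twice, so two `respond` generators stream into `chat_a` and `chat_b` independently, and `cancels=[go1, go2]` lets the Stop button cancel both events. A minimal self-contained sketch of the same pattern (the `fake_stream` handler is hypothetical, standing in for `respond`):

    import time
    import gradio as gr

    def fake_stream(message, history):
        # Hypothetical stand-in for respond(): yield the growing chat history
        # so the Chatbot re-renders after every token.
        partial = ""
        for ch in "echo: " + message:
            partial += ch
            time.sleep(0.02)
            yield history + [(message, partial)]

    with gr.Blocks() as demo:
        with gr.Row():
            chat_a = gr.Chatbot(height=300)
            chat_b = gr.Chatbot(height=300)
        message = gr.Textbox(label="Prompt")
        btn = gr.Button("Chat")
        stop_btn = gr.Button("Stop")
        # Two independent generator events fire on one click; each streams
        # into its own Chatbot, and the stop button cancels both.
        go1 = btn.click(fake_stream, [message, chat_a], chat_a)
        go2 = btn.click(fake_stream, [message, chat_b], chat_b)
        stop_btn.click(None, None, None, cancels=[go1, go2])

    if __name__ == "__main__":
        demo.launch()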
    	
app_single.py ADDED

@@ -0,0 +1,117 @@
+import gradio as gr 
+import os
+from typing import List
+import logging
+import urllib.request
+from utils import model_name_mapping, urial_template, openai_base_request
+from constant import js_code_label, HEADER_MD
+from openai import OpenAI
+import datetime
+# add logging info to console 
+logging.basicConfig(level=logging.INFO)
+
+URIAL_VERSION = "inst_1k_v4.help"
+URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"
+urial_prompt = urllib.request.urlopen(URIAL_URL).read().decode('utf-8')
+urial_prompt = urial_prompt.replace("```", '"""') # new version of URIAL uses """ instead of ```
+STOP_STRS = ['"""', '# Query:', '# Answer:']
+
+addr_limit_counter = {}
+LAST_UPDATE_TIME = datetime.datetime.now() 
+
+
+def respond(
+    message,
+    history: list[tuple[str, str]],
+    max_tokens,
+    temperature,
+    top_p,
+    rp,
+    model_name,
+    api_key,
+    request:gr.Request
+):  
+    global STOP_STRS, urial_prompt, LAST_UPDATE_TIME, addr_limit_counter
+    rp = 1.0
+    prompt = urial_template(urial_prompt, history, message)
+
+    # _model_name = "meta-llama/Llama-3-8b-hf"
+    _model_name = model_name_mapping(model_name)
+
+    if api_key and len(api_key) == 64:
+        api_key = api_key
+    else:
+        api_key = None
+
+    # headers = request.headers
+    # if already 24 hours passed, reset the counter
+    if datetime.datetime.now() - LAST_UPDATE_TIME > datetime.timedelta(days=1):
+        addr_limit_counter = {}
+        LAST_UPDATE_TIME = datetime.datetime.now()
+    host_addr = request.client.host
+    if host_addr not in addr_limit_counter:
+        addr_limit_counter[host_addr] = 0
+    if addr_limit_counter[host_addr] > 100:
+        return "You have reached the limit of 100 requests for today. Please use your own API key."
+
+    infer_request = openai_base_request(prompt=prompt, model=_model_name, 
+                                   temperature=temperature, 
+                                   max_tokens=max_tokens, 
+                                   top_p=top_p, 
+                                   repetition_penalty=rp,
+                                   stop=STOP_STRS, api_key=api_key)  
+    addr_limit_counter[host_addr] += 1
+    logging.info(f"Requesting chat completion from OpenAI API with model {_model_name}")
+    logging.info(f"addr_limit_counter: {addr_limit_counter}; Last update time: {LAST_UPDATE_TIME};")
+
+    response = ""
+    for msg in infer_request:
+        # print(msg.choices[0].delta.keys())
+        if hasattr(msg.choices[0], "delta"):
+            token = msg.choices[0].delta["content"]
+        else:
+            token = msg.choices[0].text
+        should_stop = False
+        for _stop in STOP_STRS:
+            if _stop in response + token:
+                should_stop = True
+                break
+        if should_stop:
+            break
+        response += token
+        if response.endswith('\n"'):
+            response = response[:-1]
+        elif response.endswith('\n""'):
+            response = response[:-2]
+        yield response
+
+with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown(HEADER_MD)
+            model_name = gr.Radio(["Llama-3.1-405B-FP8", "Llama-3-70B", "Llama-3-8B", 
+                                   "Mistral-7B-v0.1", 
+                                   "Mixtral-8x22B", "Qwen1.5-72B", "Yi-34B", "Llama-2-7B", "Llama-2-70B", "OLMO"]
+                                  , value="Llama-3.1-405B-FP8", label="Base LLM name")
+        with gr.Column():
+            api_key = gr.Textbox(label="🔑 APIKey", placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.", type="password", elem_id="api_key", visible=False)
+            # with gr.Column():
+            with gr.Accordion("⚙️ Parameters for Base LLM", open=True):
+                with gr.Row():
+                    max_tokens = gr.Textbox(value=256, label="Max tokens")
+                    temperature = gr.Textbox(value=0.5, label="Temperature")
+                    top_p = gr.Textbox(value=0.9, label="Top-p")
+                    rp = gr.Textbox(value=1.1, label="Repetition penalty")
+    # with gr.Row():            
+    chat = gr.ChatInterface(
+        respond,
+        additional_inputs=[max_tokens, temperature, top_p, rp, model_name, api_key],
+        # additional_inputs_accordion="⚙️ Parameters",
+        # fill_height=True, 
+    )
+    chat.chatbot.label="Chat with Base LLMs via URIAL"
+    chat.chatbot.height = 550
+    chat.chatbot.show_copy_button = True  
+
+if __name__ == "__main__": 
+    demo.launch(show_api=False)
    	
constant.py CHANGED

@@ -33,3 +33,57 @@ function addApiKeyLink() {
     }
 }
 """
+
+
+MODELS = ["Llama-3.1-405B-FP8", "Llama-3-70B", "Llama-3-8B", 
+            "Mistral-7B-v0.1", 
+            "Mixtral-8x22B", "Qwen1.5-72B", "Yi-34B", "Llama-2-7B", "Llama-2-70B", "OLMo-7B"]
+
+HYPERBOLIC_MODELS = ["meta-llama/Meta-Llama-3.1-405B-FP8", "meta-llama/Meta-Llama-3.1-405B-Instruct"]
+
+BASE_TO_ALIGNED = {
+    "Llama-3-70B": "Llama-3-70B-Instruct",
+    "Llama-3-8B": "Llama-3-8B-Instruct",
+    "Mistral-7B-v0.1": "Mistral-7B-v0.1-Instruct",
+    "Mixtral-8x22B": "Mixtral-8x22B-Instruct",
+    "Qwen1.5-72B": "Qwen1.5-72B-Instruct",
+    "Llama-3.1-405B-FP8": "Llama-3.1-405B-FP8-Instruct",
+    "Yi-34B": "Yi-34B-chat",
+    "Llama-2-7B": "Llama-2-7B-chat",
+    "Llama-2-70B": "Llama-2-70B-chat",
+    "OLMo-7B": "OLMo-7B-Instruct", 
+}
+
+
+MODEL_MAPPING = {
+    "Llama-3-8B": "meta-llama/Llama-3-8b-hf",
+    "Llama-3-70B": "meta-llama/Llama-3-70b-hf",
+    "Llama-2-7B": "meta-llama/Llama-2-7b-hf",
+    "Llama-2-70B": "meta-llama/Llama-2-70b-hf",
+    "Mistral-7B-v0.1": "mistralai/Mistral-7B-v0.1",
+    "Mixtral-8x22B": "mistralai/Mixtral-8x22B",
+    "Qwen1.5-72B": "Qwen/Qwen1.5-72B",
+    "Yi-34B": "zero-one-ai/Yi-34B",
+    "Yi-6B": "zero-one-ai/Yi-6B",
+    "OLMo-7B": "allenai/OLMo-7B",
+    "Llama-3.1-405B-FP8": "meta-llama/Meta-Llama-3.1-405B-FP8",
+    #  Aligned models below 
+    "Llama-3-70B-Instruct": "meta-llama/Meta-Llama-3-70B-Instruct-Lite",
+    "Llama-3-8B-Instruct": "meta-llama/Meta-Llama-3-8B-Instruct-Lite",
+    "Mistral-7B-v0.1-Instruct": "mistralai/Mistral-7B-Instruct-v0.1",
+    "Mixtral-8x22B-Instruct": "mistralai/Mixtral-8x22B-Instruct-v0.1",
+    "Qwen1.5-72B-Instruct": "Qwen/Qwen2-72B-Instruct",
+    "Yi-34B-chat": "zero-one-ai/Yi-34B-Chat",
+    "Llama-2-7B-chat": "meta-llama/Llama-2-7b-chat-hf",
+    "Llama-2-70B-chat": "meta-llama/Llama-2-70b-chat-hf",
+    "OLMo-7B-Instruct": "allenai/OLMo-7B-Instruct",
+    "Llama-3.1-405B-FP8-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct",
+}
+
+# import json 
+# with open("together_model_ids.json", "r") as f:
+#     TOGETHER_MODEL_IDS = json.load(f)
+
+# for _, model_id in MODEL_MAPPING.items():
+#     if model_id not in TOGETHER_MODEL_IDS + HYPERBOLIC_MODELS:
+#         print(model_id)
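For reference, the tables above compose as follows: a UI display name resolves through BASE_TO_ALIGNED to its aligned counterpart, and both names resolve through MODEL_MAPPING to provider model ids (utils.model_name_mapping is presumably a thin MODEL_MAPPING lookup; that file's diff is not shown in this section):

    from constant import BASE_TO_ALIGNED, MODEL_MAPPING

    base = "Llama-3-8B"
    aligned = BASE_TO_ALIGNED[base]    # "Llama-3-8B-Instruct"
    print(MODEL_MAPPING[base])         # meta-llama/Llama-3-8b-hf
    print(MODEL_MAPPING[aligned])      # meta-llama/Meta-Llama-3-8B-Instruct-Lite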
    	
list_models.py ADDED

@@ -0,0 +1,24 @@
+import requests
+import json 
+import os 
+
+url = "https://api.together.xyz/v1/models"
+
+headers = {
+    "accept": "application/json",
+    "Authorization": f"Bearer {os.getenv('TOGETHER_API_KEY')}"
+}
+
+response = requests.get(url, headers=headers)
+
+data = response.json()
+keywords = ["OLMO"]
+
+model_ids = []
+for item in data:
+    if any(keyword.lower() in item["id"].lower() for keyword in keywords):
+        print(item["id"])
+    model_ids.append(item["id"])
+
+with open("together_model_ids.json", "w") as f:
+    json.dump(model_ids, f, indent=4)
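A possible follow-up check on the dumped file, mirroring the commented-out validation at the end of constant.py (assumes list_models.py has already been run with a valid TOGETHER_API_KEY, so together_model_ids.json exists):

    import json

    from constant import MODEL_MAPPING, HYPERBOLIC_MODELS

    with open("together_model_ids.json") as f:
        together_ids = json.load(f)

    # Flag any mapped model id that is neither a Together id nor a Hyperbolic one.
    for display_name, model_id in MODEL_MAPPING.items():
        if model_id not in together_ids and model_id not in HYPERBOLIC_MODELS:
            print(f"{display_name}: {model_id} not served")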
    	
together_model_ids.json ADDED

@@ -0,0 +1,179 @@
+[
+    "Nexusflow/NexusRaven-V2-13B",
+    "bert-base-uncased",
+    "WizardLM/WizardLM-13B-V1.2",
+    "codellama/CodeLlama-34b-Instruct-hf",
+    "google/gemma-7b",
+    "upstage/SOLAR-10.7B-Instruct-v1.0",
+    "zero-one-ai/Yi-34B",
+    "togethercomputer/StripedHyena-Hessian-7B",
+    "meta-llama/Llama-3-70b-chat-hf",
+    "teknium/OpenHermes-2-Mistral-7B",
+    "mistralai/Mixtral-8x7B-v0.1",
+    "WhereIsAI/UAE-Large-V1",
+    "hazyresearch/M2-BERT-2k-Retrieval-Encoder-V1",
+    "togethercomputer/Llama-2-7B-32K-Instruct",
+    "Undi95/ReMM-SLERP-L2-13B",
+    "meta-llama/Meta-Llama-Guard-3-8B",
+    "Undi95/Toppy-M-7B",
+    "Phind/Phind-CodeLlama-34B-v2",
+    "stabilityai/stable-diffusion-2-1",
+    "openchat/openchat-3.5-1210",
+    "Austism/chronos-hermes-13b",
+    "microsoft/phi-2",
+    "Qwen/Qwen1.5-0.5B",
+    "Qwen/Qwen1.5-1.8B",
+    "Qwen/Qwen1.5-4B",
+    "Qwen/Qwen1.5-7B",
+    "togethercomputer/m2-bert-80M-32k-retrieval",
+    "snorkelai/Snorkel-Mistral-PairRM-DPO",
+    "Qwen/Qwen1.5-7B-Chat",
+    "Qwen/Qwen1.5-14B",
+    "Qwen/Qwen1.5-14B-Chat",
+    "Qwen/Qwen1.5-72B",
+    "Qwen/Qwen1.5-1.8B-Chat",
+    "BAAI/bge-base-en-v1.5",
+    "Snowflake/snowflake-arctic-instruct",
+    "codellama/CodeLlama-13b-Python-hf",
+    "NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT",
+    "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
+    "togethercomputer/m2-bert-80M-2k-retrieval",
+    "deepseek-ai/deepseek-coder-33b-instruct",
                "codellama/CodeLlama-34b-Python-hf",
         
     | 
| 43 | 
         
            +
                "NousResearch/Nous-Hermes-Llama2-13b",
         
     | 
| 44 | 
         
            +
                "lmsys/vicuna-13b-v1.5",
         
     | 
| 45 | 
         
            +
                "Qwen/Qwen1.5-0.5B-Chat",
         
     | 
| 46 | 
         
            +
                "codellama/CodeLlama-70b-Python-hf",
         
     | 
| 47 | 
         
            +
                "codellama/CodeLlama-7b-Instruct-hf",
         
     | 
| 48 | 
         
            +
                "NousResearch/Nous-Hermes-2-Yi-34B",
         
     | 
| 49 | 
         
            +
                "codellama/CodeLlama-13b-Instruct-hf",
         
     | 
| 50 | 
         
            +
                "BAAI/bge-large-en-v1.5",
         
     | 
| 51 | 
         
            +
                "togethercomputer/Llama-3-8b-chat-hf-int4",
         
     | 
| 52 | 
         
            +
                "meta-llama/Llama-2-13b-hf",
         
     | 
| 53 | 
         
            +
                "teknium/OpenHermes-2p5-Mistral-7B",
         
     | 
| 54 | 
         
            +
                "NousResearch/Nous-Capybara-7B-V1p9",
         
     | 
| 55 | 
         
            +
                "WizardLM/WizardCoder-Python-34B-V1.0",
         
     | 
| 56 | 
         
            +
                "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
         
     | 
| 57 | 
         
            +
                "NousResearch/Nous-Hermes-2-Mistral-7B-DPO",
         
     | 
| 58 | 
         
            +
                "togethercomputer/StripedHyena-Nous-7B",
         
     | 
| 59 | 
         
            +
                "togethercomputer/alpaca-7b",
         
     | 
| 60 | 
         
            +
                "garage-bAInd/Platypus2-70B-instruct",
         
     | 
| 61 | 
         
            +
                "google/gemma-2b",
         
     | 
| 62 | 
         
            +
                "google/gemma-2b-it",
         
     | 
| 63 | 
         
            +
                "google/gemma-7b-it",
         
     | 
| 64 | 
         
            +
                "meta-llama/Llama-2-7b-chat-hf",
         
     | 
| 65 | 
         
            +
                "allenai/OLMo-7B",
         
     | 
| 66 | 
         
            +
                "allenai/OLMo-7B-Instruct",
         
     | 
| 67 | 
         
            +
                "Qwen/Qwen1.5-4B-Chat",
         
     | 
| 68 | 
         
            +
                "stabilityai/stable-diffusion-xl-base-1.0",
         
     | 
| 69 | 
         
            +
                "Gryphe/MythoMax-L2-13b",
         
     | 
| 70 | 
         
            +
                "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
         
     | 
| 71 | 
         
            +
                "meta-llama/LlamaGuard-2-8b",
         
     | 
| 72 | 
         
            +
                "mistralai/Mistral-7B-Instruct-v0.1",
         
     | 
| 73 | 
         
            +
                "mistralai/Mistral-7B-Instruct-v0.2",
         
     | 
| 74 | 
         
            +
                "meta-llama/Meta-Llama-3-8B",
         
     | 
| 75 | 
         
            +
                "mistralai/Mistral-7B-v0.1",
         
     | 
| 76 | 
         
            +
                "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
         
     | 
| 77 | 
         
            +
                "Open-Orca/Mistral-7B-OpenOrca",
         
     | 
| 78 | 
         
            +
                "Qwen/Qwen1.5-32B",
         
     | 
| 79 | 
         
            +
                "NousResearch/Nous-Hermes-llama-2-7b",
         
     | 
| 80 | 
         
            +
                "Qwen/Qwen1.5-32B-Chat",
         
     | 
| 81 | 
         
            +
                "mistralai/Mixtral-8x22B",
         
     | 
| 82 | 
         
            +
                "Qwen/Qwen2-72B-Instruct",
         
     | 
| 83 | 
         
            +
                "Qwen/Qwen1.5-72B-Chat",
         
     | 
| 84 | 
         
            +
                "meta-llama/Meta-Llama-3-70B",
         
     | 
| 85 | 
         
            +
                "meta-llama/Llama-3-8b-hf",
         
     | 
| 86 | 
         
            +
                "deepseek-ai/deepseek-llm-67b-chat",
         
     | 
| 87 | 
         
            +
                "sentence-transformers/msmarco-bert-base-dot-v5",
         
     | 
| 88 | 
         
            +
                "zero-one-ai/Yi-6B",
         
     | 
| 89 | 
         
            +
                "lmsys/vicuna-7b-v1.5",
         
     | 
| 90 | 
         
            +
                "togethercomputer/m2-bert-80M-8k-retrieval",
         
     | 
| 91 | 
         
            +
                "microsoft/WizardLM-2-8x22B",
         
     | 
| 92 | 
         
            +
                "togethercomputer/Llama-3-8b-chat-hf-int8",
         
     | 
| 93 | 
         
            +
                "wavymulder/Analog-Diffusion",
         
     | 
| 94 | 
         
            +
                "mistralai/Mistral-7B-Instruct-v0.3",
         
     | 
| 95 | 
         
            +
                "Qwen/Qwen1.5-110B-Chat",
         
     | 
| 96 | 
         
            +
                "runwayml/stable-diffusion-v1-5",
         
     | 
| 97 | 
         
            +
                "prompthero/openjourney",
         
     | 
| 98 | 
         
            +
                "meta-llama/Llama-2-7b-hf",
         
     | 
| 99 | 
         
            +
                "SG161222/Realistic_Vision_V3.0_VAE",
         
     | 
| 100 | 
         
            +
                "meta-llama/Llama-2-13b-chat-hf",
         
     | 
| 101 | 
         
            +
                "google/gemma-2-27b-it",
         
     | 
| 102 | 
         
            +
                "zero-one-ai/Yi-34B-Chat",
         
     | 
| 103 | 
         
            +
                "meta-llama/Meta-Llama-3-70B-Instruct-Turbo",
         
     | 
| 104 | 
         
            +
                "meta-llama/Meta-Llama-3-70B-Instruct-Lite",
         
     | 
| 105 | 
         
            +
                "google/gemma-2-9b-it",
         
     | 
| 106 | 
         
            +
                "google/gemma-2-9b",
         
     | 
| 107 | 
         
            +
                "meta-llama/Llama-3-8b-chat-hf",
         
     | 
| 108 | 
         
            +
                "mistralai/Mixtral-8x7B-Instruct-v0.1",
         
     | 
| 109 | 
         
            +
                "codellama/CodeLlama-70b-hf",
         
     | 
| 110 | 
         
            +
                "togethercomputer/LLaMA-2-7B-32K",
         
     | 
| 111 | 
         
            +
                "databricks/dbrx-instruct",
         
     | 
| 112 | 
         
            +
                "meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
         
     | 
| 113 | 
         
            +
                "meta-llama/Meta-Llama-3-8B-Instruct-Turbo",
         
     | 
| 114 | 
         
            +
                "cognitivecomputations/dolphin-2.5-mixtral-8x7b",
         
     | 
| 115 | 
         
            +
                "mistralai/Mixtral-8x22B-Instruct-v0.1",
         
     | 
| 116 | 
         
            +
                "togethercomputer/evo-1-131k-base",
         
     | 
| 117 | 
         
            +
                "meta-llama/Llama-2-70b-hf",
         
     | 
| 118 | 
         
            +
                "codellama/CodeLlama-70b-Instruct-hf",
         
     | 
| 119 | 
         
            +
                "meta-llama/Meta-Llama-3-8B-Instruct-Lite",
         
     | 
| 120 | 
         
            +
                "togethercomputer/evo-1-8k-base",
         
     | 
| 121 | 
         
            +
                "meta-llama/Llama-2-70b-chat-hf",
         
     | 
| 122 | 
         
            +
                "codellama/CodeLlama-7b-Python-hf",
         
     | 
| 123 | 
         
            +
                "Meta-Llama/Llama-Guard-7b",
         
     | 
| 124 | 
         
            +
                "togethercomputer/Koala-7B",
         
     | 
| 125 | 
         
            +
                "Qwen/Qwen2-1.5B-Instruct",
         
     | 
| 126 | 
         
            +
                "Qwen/Qwen2-7B-Instruct",
         
     | 
| 127 | 
         
            +
                "NousResearch/Nous-Hermes-13b",
         
     | 
| 128 | 
         
            +
                "togethercomputer/guanaco-65b",
         
     | 
| 129 | 
         
            +
                "togethercomputer/llama-2-7b",
         
     | 
| 130 | 
         
            +
                "huggyllama/llama-7b",
         
     | 
| 131 | 
         
            +
                "lmsys/vicuna-7b-v1.3",
         
     | 
| 132 | 
         
            +
                "Qwen/Qwen2-72B",
         
     | 
| 133 | 
         
            +
                "Phind/Phind-CodeLlama-34B-Python-v1",
         
     | 
| 134 | 
         
            +
                "NumbersStation/nsql-llama-2-7B",
         
     | 
| 135 | 
         
            +
                "NousResearch/Nous-Hermes-Llama2-70b",
         
     | 
| 136 | 
         
            +
                "WizardLM/WizardLM-70B-V1.0",
         
     | 
| 137 | 
         
            +
                "huggyllama/llama-65b",
         
     | 
| 138 | 
         
            +
                "lmsys/vicuna-13b-v1.5-16k",
         
     | 
| 139 | 
         
            +
                "HuggingFaceH4/zephyr-7b-beta",
         
     | 
| 140 | 
         
            +
                "togethercomputer/llama-2-13b",
         
     | 
| 141 | 
         
            +
                "togethercomputer/CodeLlama-7b-Instruct",
         
     | 
| 142 | 
         
            +
                "togethercomputer/guanaco-13b",
         
     | 
| 143 | 
         
            +
                "togethercomputer/CodeLlama-34b-Python",
         
     | 
| 144 | 
         
            +
                "togethercomputer/CodeLlama-34b-Instruct",
         
     | 
| 145 | 
         
            +
                "togethercomputer/CodeLlama-34b",
         
     | 
| 146 | 
         
            +
                "togethercomputer/llama-2-70b",
         
     | 
| 147 | 
         
            +
                "codellama/CodeLlama-13b-hf",
         
     | 
| 148 | 
         
            +
                "Qwen/Qwen2-7B",
         
     | 
| 149 | 
         
            +
                "Qwen/Qwen2-1.5B",
         
     | 
| 150 | 
         
            +
                "togethercomputer/CodeLlama-13b-Instruct",
         
     | 
| 151 | 
         
            +
                "togethercomputer/llama-2-13b-chat",
         
     | 
| 152 | 
         
            +
                "lmsys/vicuna-13b-v1.3",
         
     | 
| 153 | 
         
            +
                "huggyllama/llama-13b",
         
     | 
| 154 | 
         
            +
                "huggyllama/llama-30b",
         
     | 
| 155 | 
         
            +
                "togethercomputer/guanaco-33b",
         
     | 
| 156 | 
         
            +
                "togethercomputer/Koala-13B",
         
     | 
| 157 | 
         
            +
                "togethercomputer/llama-2-7b-chat",
         
     | 
| 158 | 
         
            +
                "togethercomputer/SOLAR-10.7B-Instruct-v1.0-int4",
         
     | 
| 159 | 
         
            +
                "togethercomputer/guanaco-7b",
         
     | 
| 160 | 
         
            +
                "EleutherAI/llemma_7b",
         
     | 
| 161 | 
         
            +
                "meta-llama/Meta-Llama-3-8B-Instruct",
         
     | 
| 162 | 
         
            +
                "codellama/CodeLlama-34b-hf",
         
     | 
| 163 | 
         
            +
                "meta-llama/Meta-Llama-3-70B-Instruct",
         
     | 
| 164 | 
         
            +
                "meta-llama/Llama-3-70b-hf",
         
     | 
| 165 | 
         
            +
                "togethercomputer/CodeLlama-7b-Python",
         
     | 
| 166 | 
         
            +
                "NousResearch/Hermes-2-Theta-Llama-3-70B",
         
     | 
| 167 | 
         
            +
                "carson/ml318bit",
         
     | 
| 168 | 
         
            +
                "togethercomputer/CodeLlama-13b-Python",
         
     | 
| 169 | 
         
            +
                "codellama/CodeLlama-7b-hf",
         
     | 
| 170 | 
         
            +
                "togethercomputer/llama-2-70b-chat",
         
     | 
| 171 | 
         
            +
                "carson/ml31405bit",
         
     | 
| 172 | 
         
            +
                "carson/ml3170bit",
         
     | 
| 173 | 
         
            +
                "carson/mlg38b",
         
     | 
| 174 | 
         
            +
                "carson/ml318br",
         
     | 
| 175 | 
         
            +
                "meta-llama/Meta-Llama-3.1-8B-Reference",
         
     | 
| 176 | 
         
            +
                "gradientai/Llama-3-70B-Instruct-Gradient-1048k",
         
     | 
| 177 | 
         
            +
                "meta-llama/Meta-Llama-3.1-70B-Instruct-Reference",
         
     | 
| 178 | 
         
            +
                "meta-llama/Meta-Llama-3.1-70B-Reference"
         
     | 
| 179 | 
         
            +
            ]
         
     | 
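Note: this commit also adds a list_models.py helper, whose contents are not reproduced in this diff. The following is only a hypothetical sketch of how a file like together_model_ids.json can be regenerated from the Together models endpoint; the exact response shape is an assumption, so adjust the parsing if the payload differs.

# Hypothetical sketch (not the repo's list_models.py): dump the IDs of all
# models served at the Together endpoint into together_model_ids.json.
# Assumes TOGETHER_API_KEY is set in the environment.
import json
import os

import requests

resp = requests.get(
    "https://api.together.xyz/v1/models",
    headers={"Authorization": f"Bearer {os.getenv('TOGETHER_API_KEY')}"},
    timeout=30,
)
resp.raise_for_status()
payload = resp.json()
# Some OpenAI-compatible endpoints wrap the list as {"data": [...]};
# others return a bare array. Handle both.
models = payload["data"] if isinstance(payload, dict) else payload

with open("together_model_ids.json", "w") as f:
    json.dump([m["id"] for m in models], f, indent=4)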
    	
utils.py
CHANGED

@@ -3,36 +3,15 @@ from openai import OpenAI
 import logging
 from typing import List
 import os
+from constant import HYPERBOLIC_MODELS, MODEL_MAPPING
 
-BASE_URL = "https://api.together.xyz/v1"
-DEFAULT_API_KEY = os.getenv("TOGETHER_API_KEY")
 
 def model_name_mapping(model_name):
-    if model_name == "Llama-3-70B":
-        _model_name = "meta-llama/Llama-3-70b-hf"
-    elif model_name == "Llama-2-7B":
-        _model_name = "meta-llama/Llama-2-7b-hf"
-    elif model_name == "Llama-2-70B":
-        _model_name = "meta-llama/Llama-2-70b-hf"
-    elif model_name == "Mistral-7B-v0.1":
-        _model_name = "mistralai/Mistral-7B-v0.1"
-    elif model_name == "Mixtral-8x22B":
-        _model_name = "mistralai/Mixtral-8x22B"
-    elif model_name == "Qwen1.5-72B":
-        _model_name = "Qwen/Qwen1.5-72B"
-    elif model_name == "Yi-34B":
-        _model_name = "zero-one-ai/Yi-34B"
-    elif model_name == "Yi-6B":
-        _model_name = "zero-one-ai/Yi-6B"
-    elif model_name == "OLMO":
-        _model_name = "allenai/OLMo-7B"
-    elif model_name == "Qwen1.5-72B":
-        _model_name = "Qwen/Qwen1.5-72B"
+    model_mapping = MODEL_MAPPING
+    if model_name in model_mapping:
+        return model_mapping[model_name]
     else:
-        raise ValueError("Invalid model name")
-    return _model_name
+        raise ValueError("Invalid model name:", model_name)
 
 
 def urial_template(urial_prompt, history, message):
@@ -41,7 +20,14 @@ def urial_template(urial_prompt, history, message):
         current_prompt += f'# Query:\n"""\n{user_msg}\n"""\n\n# Answer:\n"""\n{ai_msg}\n"""\n\n'
     current_prompt += f'# Query:\n"""\n{message}\n"""\n\n# Answer:\n"""\n'
     return current_prompt
-
+
+def chat_template(history, message):
+    messages = []
+    for user_msg, ai_msg in history:
+        messages.append({"role": "user", "content": user_msg})
+        messages.append({"role": "assistant", "content": ai_msg})
+    messages.append({"role": "user", "content": message})
+    return messages
 
 def openai_base_request(
     model: str=None,
@@ -54,11 +40,18 @@ def openai_base_request(
     stop: List[str]=None,
     api_key: str=None,
     ):
+
+    if model in HYPERBOLIC_MODELS:
+        BASE_URL = "https://api.hyperbolic.xyz/v1"
+        DEFAULT_API_KEY = os.getenv("HYPERBOLIC_API_KEY")
+    else:
+        BASE_URL = "https://api.together.xyz/v1"
+        DEFAULT_API_KEY = os.getenv("TOGETHER_API_KEY")
+
     if api_key is None:
         api_key = DEFAULT_API_KEY
-    client = OpenAI(api_key=api_key, base_url=BASE_URL)
-
-    logging.info(f"Requesting chat completion from OpenAI API with model {model}")
+    client = OpenAI(api_key=api_key, base_url=BASE_URL)
+    logging.info(f"Requesting base completion from OpenAI API with model {model}")
     logging.info(f"Prompt: {prompt}")
     logging.info(f"Temperature: {temperature}")
     logging.info(f"Max tokens: {max_tokens}")
@@ -80,3 +73,44 @@ def openai_base_request(
 
     return request
 
+
+
+def openai_chat_request(
+    model: str=None,
+    temperature: float=0,
+    max_tokens: int=512,
+    top_p: float=1.0,
+    messages=None,
+    n: int=1,
+    repetition_penalty: float=1.0,
+    stop: List[str]=None,
+    api_key: str=None,
+    ):
+
+    if model in HYPERBOLIC_MODELS:
+        BASE_URL = "https://api.hyperbolic.xyz/v1"
+        DEFAULT_API_KEY = os.getenv("HYPERBOLIC_API_KEY")
+    else:
+        BASE_URL = "https://api.together.xyz/v1"
+        DEFAULT_API_KEY = os.getenv("TOGETHER_API_KEY")
+
+    if api_key is None:
+        api_key = DEFAULT_API_KEY
+
+    logging.info(f"Requesting chat completion from OpenAI API with model {model}")
+
+    client = OpenAI(api_key=api_key, base_url=BASE_URL)
+
+    request = client.chat.completions.create(
+        model=model,
+        messages=messages,
+        temperature=float(temperature),
+        max_tokens=int(max_tokens),
+        top_p=float(top_p),
+        n=n,
+        extra_body={'repetition_penalty': float(repetition_penalty)},
+        stop=stop,
+        stream=True
+    )
+    return request
+
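Note: the new openai_chat_request returns a stream (it creates the completion with stream=True) rather than a finished response, so callers must iterate over it. A minimal usage sketch of the intended call pattern for the new chat path, assuming utils.py and constant.py from this commit are importable and TOGETHER_API_KEY (or HYPERBOLIC_API_KEY) is set; the model ID and messages are placeholders.

# Sketch: build OpenAI-style messages from a Gradio-style history,
# then stream a chat completion through the new helper.
from utils import chat_template, openai_chat_request

history = [("Hi, who are you?", "I am an open LLM behind an OpenAI-compatible API.")]
messages = chat_template(history, "Summarize that in five words.")

stream = openai_chat_request(
    model="meta-llama/Llama-3-70b-chat-hf",  # any chat ID from together_model_ids.json
    messages=messages,
    temperature=0.5,
    max_tokens=128,
)

# The reply arrives as incremental deltas; accumulate them into one string.
reply = ""
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        reply += chunk.choices[0].delta.content
print(reply)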