File size: 1,359 Bytes
141a5cd
d375f40
141a5cd
3981b8f
141a5cd
 
 
7747319
 
 
 
141a5cd
7747319
 
 
 
141a5cd
 
c27096e
 
 
 
 
 
7f0ed3e
c27096e
141a5cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import torch
import google.auth
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model ids servable via Hugging Face Inference for PRO accounts.
# See: https://huggingface.co/blog/inference-pro
# Commented entries are known-compatible alternatives; uncomment to enable.
ALLOWED_LLM_FOR_HF_PRO_ACCOUNTS = [
    # "mistralai/Mixtral-8x7B-Instruct-v0.1",
    # "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
    # "mistralai/Mistral-7B-Instruct-v0.2",
    # "mistralai/Mistral-7B-Instruct-v0.1",
    "HuggingFaceH4/zephyr-7b-beta",
    # "meta-llama/Llama-2-7b-chat-hf",
    # "meta-llama/Llama-2-13b-chat-hf",
    # "meta-llama/Llama-2-70b-chat-hf",
    # "openchat/openchat-3.5-0106"
]

def init_model(model_id):
    """Load the tokenizer and fp16 causal-LM weights for *model_id*.

    Stores both objects in the module-level globals ``tokenizer`` and
    ``model`` so other functions in this module can use them.
    """
    global tokenizer, model

    tok = AutoTokenizer.from_pretrained(model_id)
    # Disable the tokenizer's built-in system prompt; prompts are supplied
    # explicitly by the caller.
    tok.use_default_system_prompt = False
    tokenizer = tok

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="cuda",
    )

def auth_gcp():
    """Authenticate to GCP from the service-account JSON in $GCP_CREDENTIALS.

    Writes the JSON payload to a local key file, points
    GOOGLE_APPLICATION_CREDENTIALS at it, and resolves the credentials via
    ``google.auth.default()`` so later GCP client calls pick them up.

    Raises:
        RuntimeError: if the GCP_CREDENTIALS environment variable is unset
            or empty (previously this surfaced as an opaque TypeError from
            ``f.write(None)``).
    """
    gcp_credentials = os.getenv("GCP_CREDENTIALS")
    if not gcp_credentials:
        raise RuntimeError("GCP_CREDENTIALS environment variable is not set")

    # Single source of truth for the key-file path (was duplicated as
    # "gcp-credentials.json" and './gcp-credentials.json').
    credentials_path = "gcp-credentials.json"

    # The file holds a private key — create it owner-read/write only (0600).
    fd = os.open(credentials_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as f:
        f.write(gcp_credentials)

    os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = credentials_path
    google.auth.default()

def get_env_vars():
    """Read deployment configuration from the environment into module globals.

    Populates ``gcp_project_id``, ``gcp_project_location`` and
    ``hf_access_token``; any variable that is unset becomes ``None``.
    """
    global gcp_project_id, gcp_project_location, hf_access_token

    env = os.environ
    gcp_project_id = env.get("GCP_PROJECT_ID")
    gcp_project_location = env.get("GCP_PROJECT_LOCATION")
    hf_access_token = env.get("HF_TOKEN")