# NOTE: scraper residue removed — this file came from a HuggingFace Space
# whose status banner read "Paused".
import os

import google.auth
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Models servable through a Hugging Face PRO inference account:
# https://huggingface.co/blog/inference-pro
ALLOWED_LLM_FOR_HF_PRO_ACCOUNTS = [
    "HuggingFaceH4/zephyr-7b-beta",
    # Other PRO-tier candidates, currently disabled:
    # "mistralai/Mixtral-8x7B-Instruct-v0.1",
    # "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
    # "mistralai/Mistral-7B-Instruct-v0.2",
    # "mistralai/Mistral-7B-Instruct-v0.1",
    # "meta-llama/Llama-2-7b-chat-hf",
    # "meta-llama/Llama-2-13b-chat-hf",
    # "meta-llama/Llama-2-70b-chat-hf",
    # "openchat/openchat-3.5-0106",
]
def init_model(model_id):
    """Load the tokenizer and half-precision causal LM for *model_id*.

    The loaded objects are published to the module-level ``tokenizer``
    and ``model`` globals so other parts of the app can reach them.
    """
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Disable the model card's built-in system prompt; the app supplies
    # its own prompting.
    tokenizer.use_default_system_prompt = False
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="cuda",
    )
def auth_gcp():
    """Authenticate to Google Cloud from the GCP_CREDENTIALS env var.

    Writes the JSON credential payload to ``gcp-credentials.json`` in the
    working directory, points GOOGLE_APPLICATION_CREDENTIALS at it, and
    initializes application-default credentials via ``google.auth``.

    Raises:
        RuntimeError: if GCP_CREDENTIALS is not set. (Previously an unset
            variable surfaced as an opaque ``TypeError`` from
            ``f.write(None)``.)
    """
    gcp_credentials = os.getenv("GCP_CREDENTIALS")
    if gcp_credentials is None:
        raise RuntimeError("GCP_CREDENTIALS environment variable is not set")
    # NOTE(review): this persists service-account credentials in plaintext in
    # the working directory; consider a tempfile with restricted permissions
    # if this ever runs on a shared host.
    with open("gcp-credentials.json", "w") as f:
        f.write(gcp_credentials)
    os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = './gcp-credentials.json'
    google.auth.default()
def get_env_vars():
    """Load deployment configuration from the environment into module globals.

    Populates (each is ``None`` when the variable is unset):
        gcp_project_id       -- from GCP_PROJECT_ID
        gcp_project_location -- from GCP_PROJECT_LOCATION
        hf_access_token      -- from HF_TOKEN
    """
    global gcp_project_id
    global gcp_project_location
    global hf_access_token
    env = os.environ
    gcp_project_id = env.get("GCP_PROJECT_ID")
    gcp_project_location = env.get("GCP_PROJECT_LOCATION")
    hf_access_token = env.get("HF_TOKEN", None)