# Source: Hugging Face Space file viewer (vid2persona/init.py)
# Author: chansung — commit d375f40 (verified): "Update vid2persona/init.py"
import os
import torch
import google.auth
from transformers import AutoModelForCausalLM, AutoTokenizer
# https://huggingface.co/blog/inference-pro
# Model repo ids this app permits for HF PRO-account inference.
# All entries except zephyr-7b-beta are currently disabled; uncomment a line
# to re-enable that model.
ALLOWED_LLM_FOR_HF_PRO_ACCOUNTS = [
    # "mistralai/Mixtral-8x7B-Instruct-v0.1",
    # "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
    # "mistralai/Mistral-7B-Instruct-v0.2",
    # "mistralai/Mistral-7B-Instruct-v0.1",
    "HuggingFaceH4/zephyr-7b-beta",
    # "meta-llama/Llama-2-7b-chat-hf",
    # "meta-llama/Llama-2-13b-chat-hf",
    # "meta-llama/Llama-2-70b-chat-hf",
    # "openchat/openchat-3.5-0106"
]
def init_model(model_id, torch_dtype=torch.float16, device_map="cuda"):
    """Load a tokenizer and causal-LM model into module globals.

    Args:
        model_id: Hugging Face model repository id to load.
        torch_dtype: Weight dtype passed to ``from_pretrained``
            (default ``torch.float16``, matching the original behavior).
        device_map: Device placement passed to ``from_pretrained``
            (default ``"cuda"``, matching the original behavior).

    Returns:
        The ``(tokenizer, model)`` pair. They are also bound to the module
        globals ``tokenizer`` and ``model`` for callers that rely on the
        previous side-effect-only behavior.
    """
    global tokenizer
    global model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # The app supplies its own system prompt; disable the tokenizer's default.
    tokenizer.use_default_system_prompt = False
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch_dtype, device_map=device_map
    )
    return tokenizer, model
def auth_gcp():
    """Materialize GCP service-account credentials from the environment.

    Reads JSON credentials from the ``GCP_CREDENTIALS`` environment variable,
    writes them to ``./gcp-credentials.json``, points
    ``GOOGLE_APPLICATION_CREDENTIALS`` at that file, and initializes the
    google-auth default credentials.

    Raises:
        RuntimeError: If ``GCP_CREDENTIALS`` is unset or empty.
    """
    gcp_credentials = os.getenv("GCP_CREDENTIALS")
    if not gcp_credentials:
        # Fail fast with a clear message instead of the opaque TypeError
        # that f.write(None) would raise.
        raise RuntimeError("GCP_CREDENTIALS environment variable is not set")
    with open("gcp-credentials.json", "w") as f:
        f.write(gcp_credentials)
    os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = './gcp-credentials.json'
    google.auth.default()
def get_env_vars():
    """Read deployment configuration from the environment into module globals.

    Reads ``GCP_PROJECT_ID``, ``GCP_PROJECT_LOCATION``, and ``HF_TOKEN``.
    Any of them may be ``None`` when the corresponding variable is unset.

    Returns:
        Tuple ``(gcp_project_id, gcp_project_location, hf_access_token)``.
        The values are also bound to module globals of the same names for
        callers that rely on the previous side-effect-only behavior.
    """
    global gcp_project_id, gcp_project_location
    global hf_access_token
    gcp_project_id = os.getenv("GCP_PROJECT_ID")
    gcp_project_location = os.getenv("GCP_PROJECT_LOCATION")
    hf_access_token = os.getenv("HF_TOKEN")
    return gcp_project_id, gcp_project_location, hf_access_token