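# Gradio chat demo for Be.FM-8B: loads the gated meta-llama/Meta-Llama-3.1-8B-Instruct
# base model and applies the befm/Be.FM-8B PEFT adapter on top of it. Intended to run
# as a Hugging Face Space (GPU access via `spaces.GPU`, HF_TOKEN read from Space secrets,
# optional persistent storage mounted at /data).
#
# Assumed dependencies (not pinned here): gradio, spaces, torch, transformers, peft,
# and likely accelerate for `device_map="auto"`.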
import os

import torch
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
PEFT_MODEL_ID = "befm/Be.FM-8B"

# Cache downloads on the Space's persistent storage if it is mounted at /data.
CACHE_DIR = "/data" if os.path.exists("/data") else None

# PEFT is optional: if it is not installed, serve the base model only.
USE_PEFT = True
try:
    from peft import PeftModel, PeftConfig
except Exception:
    USE_PEFT = False
    print("[WARN] 'peft' not installed; running base model only.")


def load_model_and_tokenizer():
    """Load the tokenizer and base model, then attach the Be.FM PEFT adapter if possible."""
    if HF_TOKEN is None:
        raise RuntimeError(
            "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
            "Also ensure your account has access to the gated base model."
        )
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    tok = AutoTokenizer.from_pretrained(
        BASE_MODEL_ID,
        token=HF_TOKEN,
        cache_dir=CACHE_DIR,
    )
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=dtype,
        token=HF_TOKEN,
        cache_dir=CACHE_DIR,
    )

    print(f"[INFO] Using cache directory: {CACHE_DIR}")

    if USE_PEFT:
        try:
            # Validate that the adapter config is reachable before attaching it.
            _ = PeftConfig.from_pretrained(
                PEFT_MODEL_ID,
                token=HF_TOKEN,
                cache_dir=CACHE_DIR,
            )
            model = PeftModel.from_pretrained(
                base,
                PEFT_MODEL_ID,
                token=HF_TOKEN,
                cache_dir=CACHE_DIR,
            )
            print(f"[INFO] Loaded PEFT adapter: {PEFT_MODEL_ID}")
            return model, tok
        except Exception as e:
            print(f"[WARN] Failed to load PEFT adapter: {e}")
            return base, tok
    return base, tok


# Module-level cache so the model and tokenizer are loaded only once per process.
_model = None
_tokenizer = None


def get_model_and_tokenizer():
    global _model, _tokenizer
    if _model is None:
        _model, _tokenizer = load_model_and_tokenizer()
    return _model, _tokenizer

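
# generate_response runs on GPU: on ZeroGPU hardware the `spaces.GPU` decorator requests
# GPU access for the duration of the call. It expects `messages` in the chat-template
# format built below, e.g.
#   [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].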
@spaces.GPU
@torch.inference_mode()
def generate_response(messages, max_new_tokens=512, temperature=0.7) -> str:
    model, tokenizer = get_model_and_tokenizer()
    device = model.device

    # Render the conversation with the model's chat template, then tokenize it.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered template already starts with the BOS token, so avoid adding
    # special tokens a second time here.
    enc = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=False,
    )
    enc = {k: v.to(device) for k, v in enc.items()}

    input_length = enc["input_ids"].shape[1]
    out = model.generate(
        **enc,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens, not the prompt.
    generated_text = tokenizer.decode(out[0][input_length:], skip_special_tokens=True)
    return generated_text.strip()

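
# chat_fn is the gr.ChatInterface callback. Gradio calls it with the new user message and
# the chat history, followed by the values of `additional_inputs` in the order declared
# below (system_prompt, the unused Markdown note, max_new_tokens, temperature).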
def chat_fn(message, history, system_prompt, _prompt_reference, max_new_tokens, temperature):
    messages = []

    # Fall back to the default Be.FM system prompt when the field is left empty.
    if not system_prompt:
        system_prompt = (
            "You are a Be.FM assistant. Be.FM is a family of open foundation models "
            "designed for human behavior modeling. Built on Llama 3.1 and fine-tuned on "
            "diverse behavioral datasets, Be.FM models are designed to enhance the "
            "understanding and prediction of human decision-making."
        )
    messages.append({"role": "system", "content": system_prompt})

    # Replay prior turns. History entries are role/content dicts; multimodal entries may
    # carry a list of content parts, from which only the text parts are kept.
    for msg in (history or []):
        role = msg.get("role", "user")
        content = msg.get("content", "")

        if isinstance(content, list):
            text_parts = [c.get("text", "") for c in content if c.get("type") == "text"]
            content = " ".join(text_parts)

        if content:
            messages.append({"role": role, "content": content})

    # Append the new user message (a plain string, or a dict with a "text" field).
    if message:
        if isinstance(message, dict):
            text = message.get("text", "")
        else:
            text = message

        if text:
            messages.append({"role": "user", "content": text})

    reply = generate_response(
        messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return reply

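
# Gradio UI. Keep `additional_inputs` in the same order as chat_fn's extra parameters
# (system_prompt, _prompt_reference, max_new_tokens, temperature).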
demo = gr.ChatInterface(
    fn=chat_fn,
    type="messages",  # chat_fn expects role/content dicts, not (user, bot) tuples
    chatbot=gr.Chatbot(
        label="Chat with Be.FM",
        show_label=True,
        avatar_images=(None, None),
        type="messages",
    ),
    additional_inputs=[
        gr.Textbox(
            label="System prompt (optional)",
            placeholder=(
                "You are a Be.FM assistant. Be.FM is a family of open foundation models "
                "designed for human behavior modeling. Built on Llama 3.1 and fine-"
                "tuned on diverse behavioral datasets, Be.FM models are designed to "
                "enhance the understanding and prediction of human decision-making."
            ),
            lines=2,
        ),
        gr.Markdown(
            "For system and user prompts in a variety of behavioral tasks, please refer "
            "to the appendix in our [paper](https://arxiv.org/abs/2505.23058)."
        ),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature"),
    ],
    title="Be.FM: Open Foundation Models for Human Behavior (8B)",
)


if __name__ == "__main__":
    demo.launch(share=True)