# app.py
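"""Gradio chat demo for Be.FM.

Loads the gated base model meta-llama/Meta-Llama-3.1-8B-Instruct, applies the
befm/Be.FM-8B PEFT adapter when `peft` is available, and serves the result
through gr.ChatInterface. Requires an HF_TOKEN (or HUGGINGFACEHUB_API_TOKEN)
secret with access to the base model.
"""
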
import os
import torch
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
PEFT_MODEL_ID = "befm/Be.FM-8B"
# Use /data for persistent storage to avoid re-downloading models
CACHE_DIR = "/data" if os.path.exists("/data") else None
USE_PEFT = True
try:
    from peft import PeftModel, PeftConfig  # noqa
except Exception:
    USE_PEFT = False
    print("[WARN] 'peft' not installed; running base model only.")

def load_model_and_tokenizer():
    if HF_TOKEN is None:
        raise RuntimeError(
            "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
            "Also ensure your account has access to the gated base model."
        )
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    tok = AutoTokenizer.from_pretrained(
        BASE_MODEL_ID,
        token=HF_TOKEN,
        cache_dir=CACHE_DIR,  # Use persistent storage
    )
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=dtype,
        token=HF_TOKEN,
        cache_dir=CACHE_DIR,  # Use persistent storage
    )
    print(f"[INFO] Using cache directory: {CACHE_DIR}")
    if USE_PEFT:
        try:
            _ = PeftConfig.from_pretrained(
                PEFT_MODEL_ID,
                token=HF_TOKEN,
                cache_dir=CACHE_DIR,  # Use persistent storage
            )
            model = PeftModel.from_pretrained(
                base,
                PEFT_MODEL_ID,
                token=HF_TOKEN,
                cache_dir=CACHE_DIR,  # Use persistent storage
            )
            print(f"[INFO] Loaded PEFT adapter: {PEFT_MODEL_ID}")
            return model, tok
        except Exception as e:
            print(f"[WARN] Failed to load PEFT adapter: {e}")
            return base, tok
    return base, tok

# Lazy load model and tokenizer
_model = None
_tokenizer = None

def get_model_and_tokenizer():
    global _model, _tokenizer
    if _model is None:
        _model, _tokenizer = load_model_and_tokenizer()
    return _model, _tokenizer

@spaces.GPU
@torch.inference_mode()
def generate_response(messages, max_new_tokens=512, temperature=0.7) -> str:
    model, tokenizer = get_model_and_tokenizer()
    device = model.device
    # Apply Llama 3.1 chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # add_special_tokens=False avoids a duplicated BOS token: the chat template
    # string already starts with <|begin_of_text|>
    enc = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=False,
    )
    enc = {k: v.to(device) for k, v in enc.items()}
    input_length = enc['input_ids'].shape[1]
    out = model.generate(
        **enc,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens
    generated_text = tokenizer.decode(out[0][input_length:], skip_special_tokens=True)
    return generated_text.strip()
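
# Illustrative call, assuming the model and tokenizer load successfully:
#   generate_response([{"role": "user", "content": "Hello"}], max_new_tokens=64)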

def chat_fn(message, history, system_prompt, _prompt_reference, max_new_tokens, temperature):
    # Build conversation in Llama 3.1 chat format
    messages = []
    # Add system prompt (use default if not provided)
    if not system_prompt:
        system_prompt = (
            "You are a Be.FM assistant. Be.FM is a family of open foundation models "
            "designed for human behavior modeling. Built on Llama 3.1 and fine-tuned on "
            "diverse behavioral datasets, Be.FM models are designed to enhance the "
            "understanding and prediction of human decision-making."
        )
    messages.append({"role": "system", "content": system_prompt})
    # Handle Gradio 6.0 history format
    # History format: [{"role": "user", "content": [{"type": "text", "text": "..."}]}, ...]
    for msg in (history or []):
        role = msg.get("role", "user")
        content = msg.get("content", "")
        # Extract text from structured content
        if isinstance(content, list):
            # Gradio 6.0 format: content is a list of dicts
            text_parts = [c.get("text", "") for c in content if c.get("type") == "text"]
            content = " ".join(text_parts)
        if content:
            messages.append({"role": role, "content": content})
    if message:
        # Handle message (could be string or dict in Gradio 6.0)
        if isinstance(message, dict):
            text = message.get("text", "")
        else:
            text = message
        if text:
            messages.append({"role": "user", "content": text})
    reply = generate_response(
        messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return reply

demo = gr.ChatInterface(
    fn=chat_fn,
    chatbot=gr.Chatbot(
        label="Chat with BeFM",
        show_label=True,
        avatar_images=(None, None),  # Use default avatars or provide custom image paths
    ),
    additional_inputs=[
        gr.Textbox(
            label="System prompt (optional)",
            placeholder=(
                "You are a Be.FM assistant. Be.FM is a family of open foundation models "
                "designed for human behavior modeling. Built on Llama 3.1 and fine-"
                "tuned on diverse behavioral datasets, Be.FM models are designed to "
                "enhance the understanding and prediction of human decision-making."
            ),
            lines=2,
        ),
        gr.Markdown(
            "For system and user prompts in a variety of behavioral tasks, please refer "
            "to the appendix in our [paper](https://arxiv.org/abs/2505.23058)."
        ),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature"),
    ],
    title="Be.FM: Open Foundation Models for Human Behavior (8B)",
)

if __name__ == "__main__":
    demo.launch(share=True)
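
# Example local run (outside a Space), assuming the imported packages are
# installed and the token grants access to the gated base model:
#   HF_TOKEN=hf_... python app.py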