import gradio as gr
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
# Device configuration (prioritize GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_id = "Phearion/bigbrain-v0.0.1"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Load the tokenizer and the 4-bit quantized base model
config = PeftConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place the quantized weights
)
# Load the LoRA adapter on top of the base model
model = PeftModel.from_pretrained(model, model_id)
model.eval()  # inference mode (disables dropout)
def greet(text):
    with torch.no_grad():
        # The BOS token <s> is written into the prompt manually, so tell the
        # tokenizer not to add special tokens a second time
        input_text = f"<s>### User:\n{text}\n\n### Assistant:\n"
        batch = tokenizer(input_text, return_tensors="pt", add_special_tokens=False).to(device)
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            output_tokens = model.generate(
                **batch,
                max_new_tokens=25,  # Limit response length
                do_sample=True,  # Sample from the output distribution
                pad_token_id=tokenizer.eos_token_id,  # Pad with EOS
            )
    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(
        output_tokens[0][batch["input_ids"].shape[1]:], skip_special_tokens=True
    )
    # Keep only the text before any repeated "### Assistant:" marker
    return response.split("### Assistant:")[0]
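
# Illustrative direct call (hypothetical prompt, not from the original file):
#   greet("What does PEFT stand for?")  # -> a short sampled completion (<= 25 new tokens)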
iface = gr.Interface(fn=greet, inputs="text", outputs="text",
                     title="PEFT Model for Big Brain")
iface.launch()  # Serve the app (e.g., on a Hugging Face Space)
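
# A minimal sketch of querying the running app programmatically, assuming the
# default local URL printed by iface.launch() and the gradio_client package;
# neither appears in the original file.
#
# from gradio_client import Client
# client = Client("http://127.0.0.1:7860")
# print(client.predict("What does PEFT stand for?", api_name="/predict"))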