import gradio as gr
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# Device configuration (prioritize GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
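# Hub repo holding the LoRA adapter; its PeftConfig records the base model to load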
model_id = "Phearion/bigbrain-v0.0.1"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the base model and tokenizer
config = PeftConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=bnb_config,
    device_map="auto",  # place the quantized weights on the available GPU(s)
)

# Load the Lora model
model = PeftModel.from_pretrained(model, model_id)
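model.eval()  # inference only; disables dropout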

def greet(text):
    with torch.no_grad():
        # Build the prompt; add_special_tokens=True already prepends the BOS token,
        # so the template does not need a literal "<s>"
        input_text = f"### User:\n{text}\n\n### Assistant:\n"

        batch = tokenizer(input_text, return_tensors='pt', add_special_tokens=True).to(device)

        with torch.autocast(device_type=device.type):  # mixed-precision inference on GPU or CPU
            output_tokens = model.generate(
                **batch,
                max_new_tokens=25,  # Limit response length
                do_sample=True,  # Sample from the distribution
                pad_token_id=tokenizer.eos_token_id,  # Use EOS as the padding token
            )

        # Decode only the generated tokens
        response = tokenizer.decode(output_tokens[0][len(batch['input_ids'][0]):], skip_special_tokens=True)

        # Truncate if the model starts another "### Assistant:" turn
        response_parts = response.split("### Assistant:")
        return response_parts[0].strip()  # Return only the first turn

iface = gr.Interface(
    fn=greet,
    inputs="text",
    outputs="text",
    title="PEFT Model for Big Brain",
)
iface.launch()  # Share directly to Gradio Space