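"""Gradio chat demo for a 4-bit quantized Gemma 3 12B model with a PEFT
(LoRA) adapter applied on top."""
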
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr

# Model and device setup
base_model_name = "unsloth/gemma-3-12b-it-unsloth-bnb-4bit"
adapter_name = "adarsh3601/my_gemma3_pt"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the base model. The unsloth "-bnb-4bit" checkpoint ships
# pre-quantized with an embedded quantization_config, so no explicit
# load_in_4bit flag is needed (passing it as a bare kwarg is deprecated
# in recent transformers releases in favor of BitsAndBytesConfig).
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map={"": device},
    torch_dtype=torch.float16,  # float16 compute; try bfloat16 if you hit NaNs
)
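# Note: bitsandbytes 4-bit kernels generally require a CUDA GPU; the "cpu"
# fallback above keeps the script importable, but this pre-quantized
# checkpoint will not run without one.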

# Load tokenizer and attach the PEFT adapter
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = PeftModel.from_pretrained(base_model, adapter_name)
# No model.to(device) here: .to() raises on 4-bit bitsandbytes models,
# and device_map has already placed the weights.
model.eval()

# Chat function with basic input validation and error handling
def chat(message):
    if not message or not message.strip():
        return "Please enter a valid message."

    # Gemma 3 is an instruction-tuned chat model, so format the message
    # with its chat template (assuming the adapter was trained on the same
    # format) and move the tensors to the model's device. Token ids are
    # integer tensors, so no float16 casting is needed.
    messages = [{"role": "user", "content": message}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(device)

    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.8
            )

        # Decode only the newly generated tokens; outputs[0] also contains
        # the prompt tokens.
        prompt_len = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return response.strip()

    except RuntimeError as e:
        return f"An error occurred during generation: {str(e)}"
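
# Optional smoke test before launching the UI:
# print(chat("Hello! What can you do?"))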

# Launch Gradio app
iface = gr.Interface(
    fn=chat,
    inputs="text",
    outputs="text",
    title="Gemma Chatbot"
)

iface.launch()
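
# launch() starts a local server (http://127.0.0.1:7860 by default);
# pass share=True for a temporary public link.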