# TestingModelAPI / app.py
import os

# Synchronous CUDA kernel launches make device-side errors surface at the
# failing call. Set this before torch initializes CUDA; remove in production,
# as it slows execution down.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# Base model and LoRA adapter identifiers
base_model_name = "adarsh3601/my_gemma_pt3"
adapter_name = "your_adapter_name_here"  # Replace with the actual adapter repo or path
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the tokenizer and the 4-bit quantized base model
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",          # Let accelerate assign layers to available devices
    torch_dtype=torch.float32,  # fp32 for the non-quantized modules to avoid precision issues
    load_in_4bit=True,          # Requires the bitsandbytes package
)
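
# Note: passing load_in_4bit directly to from_pretrained is deprecated in
# recent transformers releases in favor of an explicit quantization config.
# A minimal sketch of the equivalent (assuming transformers >= 4.30 with
# bitsandbytes installed):
#
#   from transformers import BitsAndBytesConfig
#   bnb_config = BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_compute_dtype=torch.float16,
#   )
#   base_model = AutoModelForCausalLM.from_pretrained(
#       base_model_name, device_map="auto", quantization_config=bnb_config
#   )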
# Attach the LoRA adapter on top of the quantized base model
model = PeftModel.from_pretrained(base_model, adapter_name)
# Note: with device_map="auto" and 4-bit weights the model is already placed
# on the GPU; calling model.to(device) on it would raise an error.
# Put the model in evaluation mode
model.eval()
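
# Once the adapter is final, PEFT can fold it into the base weights for
# slightly faster inference -- a hedged option: merge_and_unload() generally
# expects non-quantized weights, so it may not apply to this 4-bit setup:
#
#   model = model.merge_and_unload()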
# Chat function with basic input validation
def chat(message):
    # Token ids are integers, so NaN/Inf checks do not apply to them;
    # validate that the message itself is non-empty instead.
    if not message or not message.strip():
        return "Please enter a non-empty message."
    # Tokenize the input message
    inputs = tokenizer(message, return_tensors="pt")
    # Move tensors to the model's device. Keep token ids as int64:
    # casting them with .half() would break the embedding lookup.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    try:
        # Generate a response (sampling enabled for varied outputs)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=150, do_sample=True)
        # generate() returns integer token ids, so there is nothing to check
        # for NaN/Inf here; decode only the newly generated tokens so the
        # prompt is not echoed back in the response.
        prompt_length = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    except Exception as e:
        # Surface generation errors in the UI instead of crashing the app
        response = f"Unexpected error: {e}"
    return response
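
# Optional console smoke test -- a minimal sketch gated behind a hypothetical
# SMOKE_TEST environment variable; set SMOKE_TEST=1 to try one prompt from the
# terminal before the UI starts:
if os.environ.get("SMOKE_TEST"):
    print(chat("Hello! Briefly introduce yourself."))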
# Gradio interface for the chat (gradio is imported at the top of the file)
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Chat with Gemma Model")
        with gr.Row():
            message_input = gr.Textbox(label="Input Message")
            output = gr.Textbox(label="Model Response")
        # Button to trigger the chat function
        button = gr.Button("Generate Response")
        button.click(fn=chat, inputs=message_input, outputs=output)
    demo.launch()
if __name__ == "__main__":
    gradio_interface()