import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# Constants for the model
MODEL_PATH = "SeaLLMs/SeaLLMs-v3-7B-Chat"
MODEL_TITLE = "SeaLLMs Chat Model"
MODEL_DESC = "A demo for the SeaLLMs-v3-7B-Chat language model."

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Load the model with automatic device placement and CPU/disk offloading
# to keep peak memory usage manageable
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float32,   # use float32 to avoid numerical issues
    offload_folder="./offload",  # folder for offloaded weights
    low_cpu_mem_usage=True,
)


def generate_response(prompt):
    # Truncate the input to at most 512 tokens to prevent excessive memory usage
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

    # Move inputs to the same device as the model
    inputs = {key: value.to(model.device) for key, value in inputs.items()}

    # Generate the response
    try:
        with torch.no_grad():  # disable gradient tracking to save memory during inference
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,     # cap the number of generated tokens to lower memory usage
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,         # sampling must be enabled for temperature to take effect
                temperature=0.7,        # temperature scaling to control output randomness
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except RuntimeError as e:
        # Handle numerical instability or out-of-memory errors gracefully
        response = "An error occurred during generation. Please try again with a different prompt."
        print(f"RuntimeError: {e}")

    return response


# Create the Gradio interface
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=5, label="Enter your message:"),
    outputs=gr.Textbox(label="Model's response:"),
    title=MODEL_TITLE,
    description=MODEL_DESC,
    theme="default",  # you can specify any custom theme or remove this line
)

if __name__ == "__main__":
    iface.launch()
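
# ------------------------------------------------------------------------------
# Note: SeaLLMs-v3-7B-Chat is an instruction-tuned chat model, so responses are
# usually better when the user message is wrapped in the tokenizer's chat
# template rather than passed as raw text. A minimal sketch, assuming the
# tokenizer ships a chat template; build_chat_prompt is a hypothetical helper
# name used here only for illustration:
#
#     def build_chat_prompt(user_message):
#         messages = [{"role": "user", "content": user_message}]
#         return tokenizer.apply_chat_template(
#             messages, tokenize=False, add_generation_prompt=True
#         )
#
# generate_response() could then tokenize build_chat_prompt(prompt) instead of
# the raw prompt string before calling model.generate().
# ------------------------------------------------------------------------------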