# model-inference/app.py
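"""Gradio inference app for htigenai/finetune_test_2.

Loads the model with 8-bit quantization and memory-conscious settings, then
exposes a simple text-generation UI with sliders for max tokens and temperature.
"""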
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging
import sys
import gc
from contextlib import contextmanager
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Starting application...")
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
try:
logger.info("Loading tokenizer...")
model_id = "htigenai/finetune_test_2"
tokenizer = AutoTokenizer.from_pretrained(
model_id,
use_fast=False # Use slow tokenizer to save memory
)
tokenizer.pad_token = tokenizer.eos_token
logger.info("Tokenizer loaded successfully")
logger.info("Loading model in 8-bit...")
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
load_in_8bit=True, # Load in 8-bit instead of 4-bit
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
max_memory={0: "12GB", "cpu": "4GB"} # Limit memory usage
)
model.eval()
logger.info("Model loaded successfully in 8-bit")
# Clear any residual memory
gc.collect()
torch.cuda.empty_cache()
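
    # Generation helper wired to the Gradio UI below: wraps the user prompt in a
    # "### Human: ... ### Assistant:" template (assumed to match the fine-tune's
    # training format) and trims the decoded output to the assistant's reply.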
    def generate_text(prompt, max_tokens=100, temperature=0.7):
        try:
            # Format the prompt
            formatted_prompt = f"### Human: {prompt}\n\n### Assistant:"

            # Tokenize with memory-efficient settings
            inputs = tokenizer(
                formatted_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=256  # Limit input length
            ).to(model.device)

            with torch.inference_mode():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    do_sample=True,
                    top_p=0.95,
                    repetition_penalty=1.2,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    no_repeat_ngram_size=3,
                    use_cache=True
                )

            response = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract only the assistant's response
            if "### Assistant:" in response:
                response = response.split("### Assistant:")[-1].strip()

            # Clean up memory after generation
            del outputs, inputs
            gc.collect()
            torch.cuda.empty_cache()

            return response
        except Exception as e:
            logger.error(f"Error during generation: {str(e)}")
            return f"Error generating response: {str(e)}"
    # Create a memory-efficient Gradio interface
    iface = gr.Interface(
        fn=generate_text,
        inputs=[
            gr.Textbox(
                lines=3,
                placeholder="Enter your prompt here...",
                label="Input Prompt",
                max_lines=5
            ),
            gr.Slider(
                minimum=10,
                maximum=100,
                value=50,
                step=10,
                label="Max Tokens"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            )
        ],
        outputs=gr.Textbox(
            label="Generated Response",
            lines=5
        ),
        title="HTIGENAI Reflection Analyzer (8-bit)",
        description="8-bit quantized text generation. Please keep prompts concise for best results.",
        examples=[
            ["What is machine learning?", 50, 0.7],
            ["Explain quantum computing", 50, 0.7],
        ],
        cache_examples=False
    )
    # Launch with minimal memory usage
    iface.launch(
        server_name="0.0.0.0",
        share=False,
        show_error=True,
        enable_queue=True,  # Gradio 3.x launch flag; on Gradio 4+ use iface.queue() instead
        max_threads=1
    )
except Exception as e:
logger.error(f"Application startup failed: {str(e)}")
raise