# gemma-2-9b-it1 / app.py
import os
import logging
from logging.handlers import RotatingFileHandler
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_huggingface import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
# Logging setup
log_file = '/tmp/app_debug.log'
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
file_handler = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=5)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)
logger.debug("Application started")
model_id = "google/gemma-2-9b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load model with GPU availability check
if torch.cuda.is_available():
    logger.debug("GPU is available. Proceeding with GPU setup.")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
else:
    logger.warning("GPU is not available. Proceeding with CPU setup.")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        token=os.getenv('HF_TOKEN'),  # `use_auth_token` is deprecated in recent transformers
    )
model.eval()
# Create Hugging Face pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=2048,   # total token budget for prompt + completion
    do_sample=True,    # required for temperature/top_k/top_p to take effect
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.2,
)
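# Optional smoke test (illustrative only, not part of the original app): a single
# short generation through the raw pipeline confirms the model loads and decodes
# before the Gradio UI is wired up.
# sample = pipe("Write a one-line Python hello world.", max_new_tokens=16)
# logger.debug(f"Smoke test output: {sample[0]['generated_text']!r}")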
# Initialize HuggingFacePipeline model for LangChain
chat_model = HuggingFacePipeline(pipeline=pipe)
logger.debug("Model and tokenizer loaded successfully")
# Define the conversation template for LangChain
# (ChatML-style markers; note that Gemma's native chat template uses <start_of_turn>/<end_of_turn>)
template = """<|im_start|>system
{system_prompt}
<|im_end|>
{history}
<|im_start|>user
{human_input}
<|im_end|>
<|im_start|>assistant"""
# Create LangChain prompt and chain
prompt = PromptTemplate(
    template=template, input_variables=["system_prompt", "history", "human_input"]
)
chain = LLMChain(llm=chat_model, prompt=prompt)
# Prediction function using LangChain and model
def predict(message, history=None):
    # Avoid a mutable default argument; history is a list of {"role", "content"} dicts
    history = history or []
    formatted_history = "\n".join(
        f"<|im_start|>{entry['role']}\n{entry['content']}<|im_end|>" for entry in history
    )
    system_prompt = "You are a helpful coding assistant."
    try:
        result = chain.run({
            "system_prompt": system_prompt,
            "history": formatted_history,
            "human_input": message,
        })
        return result
    except Exception as e:
        logger.exception(f"Error during prediction: {e}")
        return "An error occurred."
# Gradio UI
interface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="User input"),
    outputs="text",
    allow_flagging='never',
    live=True,
)
logger.debug("Chat interface initialized; launching")
interface.launch()