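"""Gradio chat app serving google/gemma-2-9b-it through a LangChain LLMChain on a
Hugging Face Space (running on ZeroGPU)."""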
import os
import logging
from logging.handlers import RotatingFileHandler

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_huggingface import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
# Logging setup
log_file = '/tmp/app_debug.log'
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
file_handler = RotatingFileHandler(log_file, maxBytes=10 * 1024 * 1024, backupCount=5)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)

logger.debug("Application started")
# The Gemma models are gated on the Hub, so authenticate with the HF_TOKEN secret.
model_id = "google/gemma-2-9b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv('HF_TOKEN'))
# Load model with GPU availability check
if torch.cuda.is_available():
    logger.debug("GPU is available. Proceeding with GPU setup.")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        token=os.getenv('HF_TOKEN'),
    )
else:
    logger.warning("GPU is not available. Proceeding with CPU setup.")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        token=os.getenv('HF_TOKEN'),
    )
model.eval()
# Create Hugging Face pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=2048,
    do_sample=True,  # sampling must be enabled for temperature/top_k/top_p to take effect
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.2,
)
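# Note: by default the text-generation pipeline returns the prompt together with the
# completion; passing return_full_text=False would return only the newly generated text.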
# Initialize HuggingFacePipeline model for LangChain
chat_model = HuggingFacePipeline(pipeline=pipe)

logger.debug("Model and tokenizer loaded successfully")
# Define the conversation template for LangChain
# Note: these ChatML-style <|im_start|>/<|im_end|> markers are not special tokens for
# Gemma-2 (its own chat template uses <start_of_turn>/<end_of_turn>), so the tokenizer
# treats them as plain text.
template = """<|im_start|>system
{system_prompt}
<|im_end|>
{history}
<|im_start|>user
{human_input}
<|im_end|>
<|im_start|>assistant"""
# Create LangChain prompt and chain
prompt = PromptTemplate(
    template=template, input_variables=["system_prompt", "history", "human_input"]
)
chain = LLMChain(llm=chat_model, prompt=prompt)
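# Hypothetical direct invocation of the chain (the same call predict() makes below):
# chain.run({"system_prompt": "You are a helpful coding assistant.",
#            "history": "", "human_input": "Explain Python list comprehensions."})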
# Prediction function using LangChain and model
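# On a ZeroGPU Space the GPU is attached on demand; GPU-bound functions are typically
# decorated with @spaces.GPU from the `spaces` package.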
def predict(message, history=None):
    history = history or []
    # Each history entry is expected to be a dict with "role" and "content" keys.
    formatted_history = "\n".join(
        f"<|im_start|>{entry['role']}\n{entry['content']}<|im_end|>" for entry in history
    )
    system_prompt = "You are a helpful coding assistant."
    try:
        result = chain.run({
            "system_prompt": system_prompt,
            "history": formatted_history,
            "human_input": message,
        })
        return result
    except Exception as e:
        logger.exception(f"Error during prediction: {e}")
        return "An error occurred."
# Gradio UI
interface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="User input"),
    outputs="text",
    allow_flagging='never',
    live=True,  # re-runs predict as the input changes
)

logger.debug("Chat interface initialized, launching")
interface.launch()