import os
import logging
from logging.handlers import RotatingFileHandler

import gradio as gr
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, pipeline
from langchain_huggingface import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

# Logging setup: rotate the debug log at 10 MB, keeping 5 backups
log_file = "/tmp/app_debug.log"
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
file_handler = RotatingFileHandler(log_file, maxBytes=10 * 1024 * 1024, backupCount=5)
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
logger.addHandler(file_handler)

logger.debug("Application started")

model_id = "google/gemma-2-9b-it"
tokenizer = GemmaTokenizerFast.from_pretrained(model_id)

# Load model with GPU availability check
if torch.cuda.is_available():
    logger.debug("GPU is available. Proceeding with GPU setup.")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
else:
    logger.warning("GPU is not available. Proceeding with CPU setup.")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        low_cpu_mem_usage=True,
        token=True,  # Gemma weights are gated; use the locally stored Hugging Face token
    )

model.eval()

# Create Hugging Face pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=2048,          # counts prompt + generated tokens
    do_sample=True,           # required for temperature / top_k / top_p to take effect
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.2,
    return_full_text=False,   # return only the generated continuation, not the prompt
)

# Initialize HuggingFacePipeline model for LangChain
chat_model = HuggingFacePipeline(pipeline=pipe)

# Define the conversation template for LangChain.
# NOTE: Gemma-2 is trained with <start_of_turn>/<end_of_turn> turn markers rather than
# ChatML tags; for best results consider building the prompt with tokenizer.apply_chat_template.
template = """<|im_start|>system
{system_prompt}
<|im_end|>
{history}
<|im_start|>user
{human_input}
<|im_end|>
<|im_start|>assistant"""

# Create LangChain prompt and chain
prompt = PromptTemplate(
    template=template,
    input_variables=["system_prompt", "history", "human_input"],
)
chain = LLMChain(llm=chat_model, prompt=prompt)

# Prediction function using LangChain and the model
def predict(message, chat_history):
    chat_history = chat_history or []  # session state is empty on the first turn
    formatted_history = "\n".join(
        [f"<|im_start|>{entry['role']}\n{entry['content']}<|im_end|>" for entry in chat_history]
    )
    system_prompt = "You are a helpful coding assistant."
    try:
        result = chain.run(
            {
                "system_prompt": system_prompt,
                "history": formatted_history,
                "human_input": message,
            }
        )
        # Append the new exchange so the session state carries the conversation forward
        chat_history = chat_history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": result},
        ]
        return result, chat_history
    except Exception:
        logger.exception("Error during prediction")
        return "An error occurred.", chat_history

# Gradio UI: gr.State appears in both inputs and outputs so per-session history persists
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="User input"),
        gr.State([]),
    ],
    outputs=[
        gr.Textbox(label="Assistant"),
        gr.State(),
    ],
)

interface.launch()
logger.debug("Chat interface initialized and launched")