import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from langchain import PromptTemplate, LLMChain
from langchain.llms.base import LLM

base_model_name = "mouryachinta/llama-2-7b-mourya"

# Tokenizer
llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
llama_tokenizer.pad_token = llama_tokenizer.eos_token
llama_tokenizer.padding_side = "right"  # Fix for fp16

# Quantization config: 4-bit NF4 weights with fp16 compute
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)

# Model initialization
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=quant_config,
    device_map={"": 0},
)
base_model.config.use_cache = False
base_model.config.pretraining_tp = 1


# Wrap the quantized model in a LangChain-compatible LLM
class CustomLLM(LLM):
    def _call(self, prompt, stop=None, run_manager=None) -> str:
        inputs = llama_tokenizer(prompt, return_tensors="pt")
        input_ids = inputs.input_ids.to(base_model.device)
        attention_mask = inputs.attention_mask.to(base_model.device)
        # Note: `stop` sequences are not applied by this minimal wrapper.
        output = base_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=128,  # generation budget; tune as needed
            pad_token_id=llama_tokenizer.eos_token_id,
        )
        return llama_tokenizer.decode(output[0], skip_special_tokens=True)

    @property
    def _llm_type(self) -> str:
        return "custom"


# Instantiate CustomLLM and build the chain
llm = CustomLLM()

template = """Question: {question}

Answer: Let's think step by step."""
prompt = PromptTemplate(template=template, input_variables=["question"])
llm_chain = LLMChain(prompt=prompt, llm=llm)

# Gradio chat UI
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        # Append the user's turn with an empty slot for the bot's reply
        return "", history + [[user_message, None]]

    def bot(history):
        print("Question: ", history[-1][0])
        bot_message = llm_chain.run(question=history[-1][0])
        print("Response: ", bot_message)
        history[-1][1] = bot_message
        return history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, chatbot, chatbot)
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
demo.launch()
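
# Optional sanity check (a minimal sketch): to verify that the chain generates
# text before wiring up the UI, you can run a single question through it, e.g.
# just above the `with gr.Blocks()` block. The question string below is an
# arbitrary example, not part of the original script.
#
#     print(llm_chain.run(question="What is the capital of France?"))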