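"""Gradio chat demo: serves a 4-bit quantized Llama-2-7B model through a
custom LangChain LLM wrapper."""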
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from langchain import LLMChain, PromptTemplate
from langchain.llms.base import LLM
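# Base checkpoint: a Llama-2-7B fine-tune published on the Hugging Face Hub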
base_model_name = "mouryachinta/llama-2-7b-mourya"
# Tokenizer
llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
llama_tokenizer.pad_token = llama_tokenizer.eos_token  # Llama-2 ships without a pad token; reuse EOS
llama_tokenizer.padding_side = "right"  # right-padding avoids fp16 overflow issues
# Quantization Config
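# 4-bit NF4 weights keep the 7B model's weight footprint around 3.5 GB, small enough for a single GPU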
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # store weights in 4-bit precision
    bnb_4bit_quant_type="nf4",             # NormalFloat4 data type
    bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
    bnb_4bit_use_double_quant=False,       # no nested quantization of the quant constants
)
# Model Initialization: load the quantized weights directly onto GPU 0
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=quant_config,
    device_map={"": 0},  # map every module to CUDA device 0
)
base_model.config.use_cache = False   # setting carried over from the fine-tuning recipe
base_model.config.pretraining_tp = 1  # use standard (non-tensor-parallel) linear layers
# Define CustomLLM class
class CustomLLM(LLM):
    def _call(self, prompt, stop=None, run_manager=None) -> str:
        # Tokenize the prompt and move the tensors to the model's device.
        inputs = llama_tokenizer(prompt, return_tensors="pt").to(base_model.device)
        # max_new_tokens bounds only the reply; max_length would also count
        # the prompt tokens and could truncate generation immediately.
        output_ids = base_model.generate(**inputs, max_new_tokens=256)
        # Decode only the newly generated tokens so the prompt is not echoed back.
        answer = llama_tokenizer.decode(
            output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
        )
        # Honour LangChain stop sequences by truncating at the first match.
        if stop:
            for s in stop:
                cut = answer.find(s)
                if cut != -1:
                    answer = answer[:cut]
        return answer

    @property
    def _llm_type(self) -> str:
        return "custom"
# Instantiate CustomLLM
llm = CustomLLM()
template = """Question: {question}
Answer: Let's think step by step."""
prompt = PromptTemplate(template=template, input_variables=["question"])
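# LLMChain fills {question} into the template and passes the rendered prompt to the LLM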
llm_chain = LLMChain(prompt=prompt, llm=llm)
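# e.g. llm_chain.run(question="What is 4-bit quantization?") returns just the generated answer string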
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        # Append the user's turn with an empty slot for the bot's reply.
        return "", history + [[user_message, None]]

    def bot(history):
        print("Question: ", history[-1][0])
        bot_message = llm_chain.run(question=history[-1][0])
        print("Response: ", bot_message)
        history[-1][1] = bot_message
        return history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, chatbot, chatbot)
    clear.click(lambda: None, None, chatbot, queue=False)
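# queue() serializes requests so concurrent users take turns on the single GPU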
demo.queue()
demo.launch()