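"""Gradio app that classifies an Azure support issue into one of three
services (Power BI, Azure Data Factory, Azure Analysis Services) using a
4-bit quantized dolly-v2 base model with a LoRA adapter loaded via PEFT."""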
import warnings

import torch
import gradio as gr
# bitsandbytes must be installed for the 4-bit quantized load below, even
# though it is never imported directly.
from peft import PeftConfig, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

warnings.filterwarnings("ignore")
device = "cuda:0"

# 4-bit NF4 quantization with double quantization keeps the ~7B model small
# enough for a single GPU; compute happens in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# The base model ('diegi97/dolly-v2-6.9b-sharded-bf16') and its tokenizer are
# loaded once below, resolved from the LoRA adapter's PeftConfig; loading them
# here as well would pull the 7B weights into memory twice.
peft_model_id = "AdiOO7/Azure-Classifier-dolly-7B"
# peft_model_id = "SparkExpedition/Ticket-Classifier-dolly-7B"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForCausalLM.from_pretrained(
config.base_model_name_or_path,
return_dict=True,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
model = PeftModel.from_pretrained(model, peft_model_id)
# Generation settings: the classifier only needs a short label, so cap output
# at a few tokens and keep sampling conservative.
generation_config = model.generation_config
generation_config.max_new_tokens = 8
generation_config.num_return_sequences = 1
generation_config.temperature = 0.3
generation_config.top_p = 0.7
generation_config.pad_token_id = tokenizer.eos_token_id
generation_config.eos_token_id = tokenizer.eos_token_id
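# Note: temperature and top_p only take effect when sampling is enabled
# (do_sample=True); under the default greedy decoding, transformers ignores
# them and may warn.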
instruct = "From which azure service the issue is raised from {Power BI/Azure Data Factory/Azure Analysis Services}"
def generate_response(question: str) -> str:
    """Prompt the model and return the text after the assistant tag."""
    prompt = f"""
### <instruction>: {instruct}
### <human>: {question}
### <assistant>:
""".strip()
    encoding = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            generation_config=generation_config,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded text includes the prompt; keep only what follows the
    # assistant tag.
    assistant_start = "<assistant>:"
    response_start = response.find(assistant_start)
    return response[response_start + len(assistant_start):].strip()
# The label spellings must match the adapter's outputs ('PowerBI' without a
# space, unlike 'Power BI' in the instruction text).
labels = ['PowerBI', 'Azure Data Factory', 'Azure Analysis Services']

def answer_prompt(prompt):
    response = generate_response(prompt)
    for lab in labels:
        if response.find(lab) != -1:
            return lab
    # Fall back to the raw generation if no known label is found, rather than
    # silently returning None.
    return response
iface = gr.Interface(
    fn=answer_prompt,
    inputs=gr.Textbox(lines=5, label="Enter Your Issue"),
    outputs=gr.Textbox(lines=5, label="Generated Answer"),
)
iface.launch()
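# Hypothetical quick check (the exact output depends on the fine-tuned
# adapter):
#   print(answer_prompt("My Power BI dashboard refresh keeps failing"))
# When running locally rather than on Spaces, iface.launch(share=True)
# exposes a temporary public URL.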