import streamlit as st
import uuid
import sys
import requests
from peft import *
import bitsandbytes as bnb
import pandas as pd
import torch
import torch.nn as nn
import transformers
from datasets import load_dataset
from huggingface_hub import notebook_login
from peft import (
LoraConfig,
PeftConfig,
get_peft_model,
prepare_model_for_kbit_training,
)
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
# Avatar images rendered beside user and assistant messages.
USER_ICON = "images/user-icon.png"
AI_ICON = "images/ai-icon.png"
# History length that handle_input checks before trimming chat_history.
MAX_HISTORY_LENGTH = 5

# Give every browser session a stable random identifier.
if 'user_id' not in st.session_state:
    st.session_state['user_id'] = str(uuid.uuid4())
user_id = st.session_state['user_id']

# Seed each piece of per-session state exactly once; later reruns of the
# script keep whatever the session already holds.
for _state_key, _initial_value in (
    ('chat_history', []),
    ("chats", [
        {
            'id': 0,
            'question': '',
            'answer': ''
        }
    ]),
    ("questions", []),
    ("answers", []),
    ("input", ""),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

# Page-level markup hook (currently only a newline).
st.markdown("""
""", unsafe_allow_html=True)
def write_top_bar():
    """Render the header row: assistant icon, title and a "Clear Chat" button.

    Returns:
        The value returned by ``st.button("Clear Chat")`` for this script run.
    """
    icon_col, title_col, button_col = st.columns([1, 10, 2])
    with icon_col:
        st.image(AI_ICON, use_column_width='always')
    with title_col:
        header = "Cogwise Intelligent Assistant"
        # The literal previously spanned several physical lines inside a
        # single-quoted f-string, which Python rejects; the same text is now
        # written with explicit newlines.
        # NOTE(review): unsafe_allow_html suggests HTML heading markup once
        # wrapped the title and was lost -- confirm and restore if so.
        st.write(f"\n{header}\n", unsafe_allow_html=True)
    with button_col:
        clear = st.button("Clear Chat")
    return clear
# Draw the header; when "Clear Chat" was pressed, reset the whole transcript
# and empty the input box.
clear = write_top_bar()
if clear:
    st.session_state["chat_history"] = []
    st.session_state.input = ""
    st.session_state.answers = []
    st.session_state.questions = []
@st.cache_resource
def _load_generation_stack():
    """Load the 4-bit base model, its LoRA adapter and the tokenizer.

    Wrapped in ``st.cache_resource`` so the load happens once and the result
    is reused on later calls instead of being repeated for every question.

    Returns:
        A ``(model, tokenizer, generation_config)`` tuple.
    """
    import os

    # Package versions pinned by the install cell this code was ported from:
    # bitsandbytes 0.39.0, torch 2.0.1, datasets 2.12.0, loralib 0.1.1,
    # einops, transformers@e03a9cc, peft@42a184f, accelerate@c9fbb71.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        # The keyword is ``bnb_4bit_use_double_quant``; the earlier spelling
        # ``load_4bit_use_double_quant`` is not a BitsAndBytesConfig option.
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    peft_model_id = "nisaar/falcon7b-Indian_Law_150Prompts"
    config = PeftConfig.from_pretrained(peft_model_id)
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        return_dict=True,
        quantization_config=bnb_config,
        device_map="auto",
        # NOTE(review): this runs model code downloaded from the Hub repo.
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token
    model = PeftModel.from_pretrained(model, peft_model_id)

    # Set the options on the config object itself (attribute access, not
    # separate ``generation_config_*`` variables) so generate() sees them.
    generation_config = model.generation_config
    generation_config.max_new_tokens = 200
    generation_config.temperature = 1
    generation_config.top_p = 0.7
    generation_config.num_return_sequences = 1
    generation_config.pad_token_id = tokenizer.eos_token_id
    generation_config.eos_token_id = tokenizer.eos_token_id
    return model, tokenizer, generation_config


def _generate_response(question: str) -> str:
    """Return the model's completion for *question*, without the prompt text."""
    model, tokenizer, generation_config = _load_generation_stack()
    device = "cuda:0"

    # NOTE(review): the role prefixes are bare ':' -- if the adapter was
    # fine-tuned with explicit speaker tags they appear to have been lost;
    # confirm against the training prompt format.
    prompt = f": {question}\n:"

    encoding = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            generation_config=generation_config,
        )

    # The output sequence starts with the prompt tokens. Decode only what
    # follows them: searching the decoded text for ':' would match the
    # prompt's own leading ':' and leave the question inside the answer.
    prompt_length = encoding.input_ids.shape[-1]
    completion = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )
    return completion.strip()


def handle_input():
    """Answer the question currently in the text box and record the exchange.

    Registered as the ``on_change`` callback of the ``"input"`` text box.
    Reads the question from ``st.session_state.input``, appends it to
    ``st.session_state.questions``, generates an answer, appends that to
    ``st.session_state.answers`` and ``chat_history``, then empties the box.
    Blank submissions are ignored.
    """
    question = st.session_state.input
    if not question.strip():
        # Nothing to ask; just make sure the box is empty.
        st.session_state.input = ""
        return

    st.session_state.questions.append({
        'question': question,
        'id': len(st.session_state.questions),
    })

    # Store the generated text itself (print() would yield None).
    answer = _generate_response(question)

    # Mutate the stored list in place so the update persists, and keep only
    # the most recent MAX_HISTORY_LENGTH exchanges by dropping the oldest.
    history = st.session_state["chat_history"]
    history.append((question, answer))
    del history[:-MAX_HISTORY_LENGTH]

    st.session_state.answers.append({
        'answer': answer,
        'id': len(st.session_state.questions),
    })
    st.session_state.input = ""
def write_user_message(md):
    """Show one user question: avatar on the left, question text on the right.

    ``md`` is a mapping whose ``'question'`` entry holds the text to display.
    """
    avatar_col, text_col = st.columns([1, 12])
    with text_col:
        st.warning(md['question'])
    with avatar_col:
        st.image(USER_ICON, use_column_width='always')
def render_answer(answer):
    """Show one assistant answer: avatar on the left, answer text on the right."""
    avatar_col, text_col = st.columns([1, 12])
    with text_col:
        st.info(answer)
    with avatar_col:
        st.image(AI_ICON, use_column_width='always')
def write_chat_message(md, q):
    """Render the ``'answer'`` entry of ``md`` inside its own container.

    ``q`` is accepted for the caller's convenience but is not used here.
    """
    with st.container():
        render_answer(md['answer'])
# Replay the stored conversation, pairing each question with its answer.
with st.container():
    exchanges = zip(st.session_state.questions, st.session_state.answers)
    for question_entry, answer_entry in exchanges:
        write_user_message(question_entry)
        write_chat_message(answer_entry, question_entry)

st.markdown('---')
# Editing the box triggers handle_input via on_change; the box value lives
# in st.session_state under the "input" key.
input = st.text_input(
    "You are talking to an AI, ask any question.",
    key="input",
    on_change=handle_input,
)