Spaces:
Runtime error
Runtime error
import numpy as np | |
import pandas as pd | |
from string import Template | |
import streamlit as st | |
import base64 | |
from datasets import load_dataset | |
from datasets import Dataset | |
import torch | |
from tqdm import tqdm | |
from peft import LoraConfig, get_peft_model | |
import transformers | |
# from transformers import AutoModelForCausalLM, AdapterConfig | |
from transformers import AutoConfig,AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer | |
from transformers import TrainingArguments | |
from peft import LoraConfig | |
from peft import * | |
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM | |
from langchain.prompts import PromptTemplate | |
from IPython.display import Markdown, display | |
# Directory holding the PEFT adapter that is applied on top of the base model.
peft_model_id = "./"

# 4-bit NF4 weight quantisation with double quantisation; compute in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)


@st.cache_resource
def _load_peft_model(adapter_path):
    """Load the quantised base model, its tokenizer and the PEFT adapter.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` makes those reruns reuse the objects loaded the
    first time instead of reloading the model weights on every click.

    Args:
        adapter_path: Location of the PEFT adapter; its config names the
            base model to load.

    Returns:
        Tuple ``(peft_config, peft_model, tokenizer)``.
    """
    peft_config = PeftConfig.from_pretrained(adapter_path)
    base_model = AutoModelForCausalLM.from_pretrained(
        peft_config.base_model_name_or_path,
        return_dict=True,
        quantization_config=bnb_config,
        device_map="auto",
        # NOTE(review): this allows modelling code from the base model's
        # repository to execute locally; keep only for trusted base models.
        trust_remote_code=True,
    )
    tok = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
    # Reuse EOS as the padding token (presumably the base tokenizer has none).
    tok.pad_token = tok.eos_token
    peft_model = PeftModel.from_pretrained(base_model, adapter_path)
    return peft_config, peft_model, tok


# Module-level names kept for the rest of the file (get_ans uses model/tokenizer).
config, model, tokenizer = _load_peft_model(peft_model_id)
# Instruction template for a five-option multiple-choice question. It is filled
# with the question text ({prompt}) and the option texts ({a}..{e}) and ends
# with "Answer: ", presumably so that the next token the model predicts is an
# option letter.
# NOTE(review): each "\n" escape sits in front of a real newline, so every
# option line is followed by a blank line, and the "'" after "correct." looks
# like a stray typo. This text presumably mirrors the prompt the adapter was
# fine-tuned on, so do not "clean it up" without re-checking accuracy.
prompt_template = """Answer the following multiple choice question by giving the most appropriate response. Answer should be one among [A, B, C, D, E] \
in order of the most likely to be correct to the least likely to be correct.'
Question: {prompt}\n
A) {a}\n
B) {b}\n
C) {c}\n
D) {d}\n
E) {e}\n
Answer: """
# LangChain wrapper exposing the six placeholders above as input variables.
prompt = PromptTemplate(template=prompt_template, input_variables=['prompt', 'a', 'b', 'c', 'd', 'e'])
def format_text_to_prompt(example):
    """Render one multiple-choice record through the module-level ``prompt``.

    Args:
        example: Mapping with a ``'prompt'`` key (the question) and option
            keys ``'A'`` through ``'E'``.

    Returns:
        ``{"ans": <filled-in prompt text>}``, a shape suitable for a dataset
        ``map`` call.
    """
    placeholders = {"prompt": example["prompt"]}
    # Template placeholders are the lower-case forms of the option keys.
    for option_key in "ABCDE":
        placeholders[option_key.lower()] = example[option_key]
    filled = prompt.format(**placeholders)
    return {"ans": filled}
def get_ans(text, top_k=3):
    """Rank the answer letters A-E by the model's logits at the last position.

    Args:
        text: Prompt to score. The logits at its final token are compared for
            the tokens of " A" .. " E".
        top_k: Number of letters to return (defaults to 3, the previous fixed
            behaviour).

    Returns:
        List of up to ``top_k`` option letters, highest logit first.
    """
    inputs = tokenizer(text, return_tensors="pt")
    # With device_map="auto" the weights may not be on the default CUDA
    # device (or on CUDA at all); send the inputs to wherever the model lives.
    device = model.device
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        logits = model(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
        ).logits[0, -1]
    # (score, letter) pairs; the last token id of " X" is taken as the id of
    # the letter itself (tokenizer-dependent, as in the original code).
    scored = [
        (logits[tokenizer(f" {letter}").input_ids[-1]].item(), letter)
        for letter in ("A", "B", "C", "D", "E")
    ]
    # Sort plain floats instead of tensors; ties still break by letter, as before.
    scored.sort(reverse=True)
    return [letter for _, letter in scored[:top_k]]
def get_base64_of_bin_file(bin_file):
    """Read *bin_file* as raw bytes and return them base64-encoded as text.

    Args:
        bin_file: Path of the file to encode.

    Returns:
        The base64 encoding of the file's contents, as a ``str``.
    """
    with open(bin_file, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode()
def set_png_as_page_bg(png_file):
    """Inject page CSS that styles the app's background, sidebar, header and toolbar.

    The main view gets a remote background image, and the sidebar gets the
    local image *png_file* embedded as a base64 data URI. The header is made
    transparent and the toolbar is placed 2rem from the right.

    Args:
        png_file: Path to the sidebar background image. Despite the name, any
            image type works: the data-URI MIME type is derived from the file
            extension, falling back to ``image/png`` when it cannot be guessed.
    """
    import mimetypes  # only needed here

    img = get_base64_of_bin_file(png_file)
    # The caller passes a .jpg; always labelling the data as image/png was wrong.
    mime_type = mimetypes.guess_type(png_file)[0] or "image/png"
    page_bg_img = f"""
<style>
[data-testid="stAppViewContainer"] > .main {{
    background-image: url("https://www.tata.com/content/dam/tata/images/verticals/desktop/banner_travel_umaidbhavan_desktop_1920x1080.jpg");
    background-size: 200%;
    background-position: center;
    background-repeat: no-repeat;
    background-attachment: local;
}}
[data-testid="stSidebar"] > div:first-child {{
    background-image: url("data:{mime_type};base64,{img}");
    background-position: center;
    background-repeat: no-repeat;
    background-attachment: fixed;
}}
[data-testid="stHeader"] {{
    background: rgba(0,0,0,0);
}}
[data-testid="stToolbar"] {{
    right: 2rem;
}}
</style>
"""
    st.markdown(page_bg_img, unsafe_allow_html=True)
def get_base64_encoded_image(image_path):
    """Return the contents of *image_path* base64-encoded as a ``str``.

    This is the same operation as ``get_base64_of_bin_file`` (base64 output
    is pure ASCII, so decoding as UTF-8 or with the default codec gives the
    same text). It delegates to that helper instead of duplicating the logic.

    Args:
        image_path: Path of the image file to encode.

    Returns:
        The base64 encoding of the file's bytes.
    """
    return get_base64_of_bin_file(image_path)
def main():
    """Render the Sci-mcq-GPT Streamlit page.

    Applies the background CSS, puts a linked PDF icon and an illustration in
    the sidebar, and, when "Get Answer" is pressed, shows the model's
    top-ranked option letters for the entered question.
    """
    set_png_as_page_bg("net_technology_5407.jpg")
    image_path = "artificial-intelligence.jpg"  # Replace with the actual image file path
    st.title("Sci-mcq-GPT")

    # Sidebar: clickable PDF icon (inlined as a base64 PNG) pointing at the document link.
    link = "https://drive.google.com/file/d/1_2TqNNyoczhxIBmU7BpOzEi2bu3MC-sx/view?usp=sharing"
    icon_path = "pdf download logo.png"
    encoded_image = get_base64_encoded_image(icon_path)
    lnk = f'<a href="{link}"><img src="data:image/png;base64,{encoded_image}" width="50" height="50"></a>'
    sidebar = st.sidebar
    sidebar.markdown(lnk, unsafe_allow_html=True)

    st.subheader("Ask Q&A")
    col1, col2 = st.columns(2)
    query = col1.text_area("Enter your question")
    response = ""
    if col1.button("Get Answer"):
        if query.strip():
            # NOTE(review): the raw text is scored without the multiple-choice
            # prompt template; users presumably paste the full question with
            # its options. Confirm this matches how the adapter was trained.
            ans = get_ans(query)
            print(ans)
            # get_ans returns a list of letters; text_area needs a string.
            response = ", ".join(ans)
        else:
            # Avoid running the model on an empty prompt.
            col1.warning("Please enter a question first.")
    col2.text_area("Sci-mcq-GPT Response", value=response)

    sidebar.image(image_path, caption=" ", width=300)


if __name__ == "__main__":
    main()