# app.py: last commit 91edf12 ("Update app.py") by matthew chung (1.21 kB).
import streamlit as st
import torch
import transformers
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import torch
import torch.nn as nn
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from datasets import Dataset
import pandas as pd
import transformers
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
import time
# Hugging Face Hub id of the LoRA adapter applied on top of the base model.
peft_model_id = "foobar8675/bloom-7b1-lora-tagger"


@st.cache_resource
def _load_tagger(adapter_id, base_model_id="bigscience/bloom-7b1"):
    """Load the adapter config, the 8-bit base model, its tokenizer and the adapter.

    Streamlit re-runs this script from the top on every widget interaction.
    ``st.cache_resource`` keeps the returned objects alive across reruns, so
    the large base model is loaded once per server process instead of on
    every rerun.

    Args:
        adapter_id: Hub id (or local path) of the PEFT adapter.
        base_model_id: Hub id of the model the adapter is applied to.

    Returns:
        Tuple ``(adapter_config, peft_model, tokenizer)``.
    """
    adapter_config = PeftConfig.from_pretrained(adapter_id)
    # NOTE(review): the base model id is given explicitly rather than taken
    # from adapter_config.base_model_name_or_path; confirm they agree.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        return_dict=True,
        load_in_8bit=True,
        device_map="auto",
    )
    base_tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    # Wrap the base model with the LoRA adapter weights.
    peft_model = PeftModel.from_pretrained(base_model, adapter_id)
    # Inference only: make sure every module (including any dropout in the
    # injected adapter layers) is in eval mode. Harmless if already set.
    peft_model.eval()
    return adapter_config, peft_model, base_tokenizer


config, model, tokenizer = _load_tagger(peft_model_id)
# Prompt box; the model was fine-tuned on inputs of the form  "<<report>>" ->:
text = st.text_area('enter text in this format : "<<report>>" ->: ')

# Streamlit re-runs the script on every interaction, so only generate when
# the user has actually typed something (blank/whitespace input is skipped).
if text.strip():
    start_time = time.perf_counter()
    # Put the token tensors on the same device as the model's parameters.
    batch = tokenizer(text, return_tensors="pt").to(model.device)
    output_tokens = model.generate(**batch, max_new_tokens=25)
    # Decoded text contains the prompt followed by at most 25 generated tokens.
    out = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    # st.json expects JSON data; the model output and the timing line are
    # plain text, so display them with text/write instead.
    st.text(out)
    st.write(f"Elapsed time: {time.perf_counter() - start_time:.2f}s")