Spaces:
Running
Running
File size: 4,776 Bytes
03746f8 eba8d8b 03746f8 eba8d8b 03746f8 e9e49b2 03746f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import torch
from peft import PeftModel, PeftConfig
import transformers
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BloomForCausalLM, GenerationConfig
from transformers.models.opt.modeling_opt import OPTDecoderLayer
# Hub identifiers: the base BLOOM checkpoint and the LoRA adapter applied on top of it.
BASE_MODEL = "bigscience/bloom-3b"
LORA_WEIGHTS = "jslin09/LegalChatbot-bloom-3b"

# Adapter configuration as published alongside the LoRA weights.
config = PeftConfig.from_pretrained(LORA_WEIGHTS)

# The tokenizer is loaded from the "bigscience/bloom" repository, not from BASE_MODEL.
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom")
# Choose the device used for the model and the input tensors: CUDA when
# available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Switch to Apple's MPS backend when it reports itself as available. PyTorch
# builds that predate the backend have no `torch.backends.mps`, so only the
# resulting AttributeError is tolerated; any other failure is surfaced instead
# of being silently swallowed.
try:
    if torch.backends.mps.is_available():
        device = "mps"
except AttributeError:
    pass
# Load the base BLOOM model, then wrap it with the LoRA adapter. The loading
# options depend on the device selected above.
if device == "cuda":
    # GPU: 8-bit weights, fp16 dtype, automatic device placement.
    model = BloomForCausalLM.from_pretrained(
        BASE_MODEL,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16)
elif device == "mps":
    # Apple MPS: fp16 with every module mapped to the single device.
    model = BloomForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map={"": device},
        torch_dtype=torch.float16,
    )
    model = PeftModel.from_pretrained(
        model,
        LORA_WEIGHTS,
        device_map={"": device},
        torch_dtype=torch.float16,
    )
else:
    # CPU: default dtype, with low_cpu_mem_usage enabled for the base model.
    model = BloomForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map={"": device},
        low_cpu_mem_usage=True,
    )
    model = PeftModel.from_pretrained(
        model,
        LORA_WEIGHTS,
        device_map={"": device},
    )
def generate_prompt(instruction, input=None):
    """Build the English instruction prompt.

    Args:
        instruction: The task description inserted under "### Instruction:".
        input: Optional extra context. When truthy it is inserted under
            "### Input:"; when falsy (None or an empty string) the shorter
            template without that section is used.

    Returns:
        The prompt text, ending with "### Response:".
    """
    if input:
        return (
            "Below is an instruction that describes a task, paired with an input "
            "that provides further context. Write a response that appropriately "
            "completes the request.\n"
            "### Instruction:\n"
            f"{instruction}\n"
            "### Input:\n"
            f"{input}\n"
            "### Response:"
        )
    return (
        "Below is an instruction that describes a task. Write a response that "
        "appropriately completes the request.\n"
        "### Instruction:\n"
        f"{instruction}\n"
        "### Response:"
    )
def generate_prompt_tw(instruction, input=None):
    """Build the Traditional Chinese instruction prompt.

    Args:
        instruction: The task description inserted under "### 指令:".
        input: Optional extra context. When truthy it is inserted under
            "### 輸入:"; when falsy (None or an empty string) the shorter
            template without that section is used.

    Returns:
        The prompt text, ending with "### 回應:".
    """
    if input:
        return (
            "以下是描述任務的指令,並與提供進一步上下文的輸入配對。編寫適當完成請求的回應。\n"
            "### 指令:\n"
            f"{instruction}\n"
            "### 輸入:\n"
            f"{input}\n"
            "### 回應:"
        )
    return (
        "以下是描述任務的指令。編寫適當完成請求的回應。\n"
        "### 指令:\n"
        f"{instruction}\n"
        "### 回應:"
    )
# Put the model in evaluation mode before serving requests.
model.eval()

# Compile the model on PyTorch 2 or newer. The major version is compared as an
# integer because comparing version strings is lexicographic and would order
# "10.0" before "2".
if int(torch.__version__.split(".")[0]) >= 2:
    model = torch.compile(model)
def evaluate(
    instruction,
    input=None,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    max_new_tokens=128,
    **kwargs,
):
    """Generate the model's response to an instruction.

    Args:
        instruction: The task description placed in the prompt.
        input: Optional extra context for the prompt.
        temperature: Sampling temperature. A value of 0 or less disables
            sampling for this call, since sampling needs a strictly positive
            temperature.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k sampling cutoff (used only when sampling).
        num_beams: Number of beams used during generation.
        max_new_tokens: Maximum number of tokens to generate.
        **kwargs: Extra options forwarded to ``GenerationConfig``.

    Returns:
        The text following the prompt's response marker, stripped of
        surrounding whitespace. If no known marker is found in the decoded
        output, the whole decoded output is returned (stripped).
    """
    # For the Chinese template, call generate_prompt_tw here instead.
    prompt = generate_prompt(instruction, input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)

    # The UI temperature slider starts at 0, so guard against a non-positive
    # temperature by falling back to non-sampling decoding. Sampling-only
    # options are passed only when sampling is active.
    sampling = temperature > 0
    config_args = {"num_beams": int(num_beams), "do_sample": sampling}
    if sampling:
        config_args.update(temperature=temperature, top_p=top_p, top_k=int(top_k))
    config_args.update(kwargs)
    generation_config = GenerationConfig(**config_args)

    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=int(max_new_tokens),
        )

    s = generation_output.sequences[0]
    # Drop special tokens (e.g. end-of-sequence) so they do not reach the UI.
    output = tokenizer.decode(s, skip_special_tokens=True)

    # The decoded sequence still contains the prompt. Accept the response
    # marker of either template so the extraction always matches the prompt
    # that was actually sent; previously the English prompt was paired with
    # the Chinese marker and indexing [1] raised IndexError.
    for marker in ("### Response:", "### 回應:"):
        parts = output.split(marker)
        if len(parts) > 1:
            return parts[1].strip()
    return output.strip()
# Build the Gradio UI around `evaluate` and start serving it. The input
# components are listed in the same order as `evaluate`'s positional
# parameters: instruction, input, temperature, top_p, top_k, num_beams,
# max_new_tokens.
gr.Interface(
    fn=evaluate,
    inputs=[
        gr.components.Textbox(
            lines=2, label="Instruction", placeholder="Tell me about alpacas."
        ),
        gr.components.Textbox(lines=2, label="Input", placeholder="none"),
        gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
        gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
        gr.components.Slider(minimum=0, maximum=100, step=1, value=40, label="Top k"),
        gr.components.Slider(minimum=1, maximum=4, step=1, value=4, label="Beams"),
        gr.components.Slider(
            minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
        ),
    ],
    outputs=[
        gr.components.Textbox(
            lines=5,
            label="Output",
        )
    ],
    title="🌲 🌲 🌲 BLOOM-LoRA-LegalChatbot",
    description="BLOOM-LoRA-LegalChatbot is a 3B-parameter BLOOM model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and my Legal QA dataset, and makes use of the Huggingface BLOOM implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",
).launch()