import sys

import torch
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
def main(
    load_8bit: bool = False,
    base_model: str = "decapoda-research/llama-7b-hf",
    lora_weights: str = "ohmreborn/llama-lora-7b",
):
    device = "cpu"
    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=load_8bit,
        max_memory={"cpu": "15GiB"},  # cap the CPU RAM used for the weights
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    # Attach the LoRA adapter on top of the frozen base model.
    model = PeftModel.from_pretrained(
        model,
        lora_weights,
        device_map={"": device},
    )
    # Standard Llama special-token ids: pad/unk = 0, bos = 1, eos = 2.
    model.config.pad_token_id = tokenizer.pad_token_id = 0
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2
    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)  # PyTorch 2.x compilation (not supported on Windows)
    return model, tokenizer
model, tokenizer = main()
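# On a CUDA machine the same entry point could load the weights in 8-bit to
# reduce memory use (a sketch; assumes the bitsandbytes package is installed,
# which this script does not require by default):
# model, tokenizer = main(load_8bit=True)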
from typing import Union

import requests


class Prompter(object):
    def __init__(self):
        # Fetch the Alpaca prompt template from the tloen/alpaca-lora repo.
        url = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/templates/alpaca.json"
        response = requests.get(url)
        self.template = response.json()

    def generate_prompt(
        self,
        instruction: str,
        input: Union[None, str] = None,
        label: Union[None, str] = None,
    ) -> str:
        # Pick the template variant depending on whether an input is supplied.
        if input:
            res = self.template["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            res = self.template["prompt_no_input"].format(
                instruction=instruction
            )
        if label:
            res = f"{res}{label}"  # append the expected answer when training
        return res

    def get_response(self, output: str) -> str:
        # Everything after the response marker is the model's answer.
        return output.split(self.template["response_split"])[1].strip()
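# For reference, the fetched alpaca.json provides the three keys used above.
# This inlined copy is an illustration taken from the tloen/alpaca-lora repo
# at the time of writing (the live file may differ); it is not used by the app:
ALPACA_TEMPLATE_EXAMPLE = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request.\n\n### Instruction:\n{instruction}\n\n"
        "### Input:\n{input}\n\n### Response:\n"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. Write a response that "
        "appropriately completes the request.\n\n### Instruction:\n"
        "{instruction}\n\n### Response:\n"
    ),
    "response_split": "### Response:",
}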
def generate(
    input=None,
    temperature=0.75,  # <1 sharpens the softmax distribution, making the model more confident: https://stackoverflow.com/questions/58764619/why-should-we-use-temperature-in-softmax/63471046#63471046
    top_p=0.95,  # nucleus sampling: sample only from the smallest set of top tokens whose cumulative probability exceeds 0.95: https://www.linkedin.com/pulse/text-generation-temperature-top-p-sampling-gpt-models-selvakumar
    top_k=50,  # keep only the 50 most likely tokens; when combined with top_p, the top-k cutoff is applied first and top_p filtering then runs on that subset: https://docs.cohere.com/docs/controlling-generation-with-top-k-top-p#2-pick-from-amongst-the-top-tokens-top-k
    max_new_tokens=1024,
    instruction="Please create an inference question in the style of TOEFL reading comprehension section. Also provide an answer in the format",
    model=model,
    tokenizer=tokenizer,
):
    prompter = Prompter()
    prompt = prompter.generate_prompt(instruction, input)
    print(prompt)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"]
    generation_config = GenerationConfig(
        do_sample=True,  # required for temperature/top_p/top_k to take effect
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=1.2,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    return prompter.get_response(output)
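# Example call outside the UI (a sketch; output varies between runs because
# sampling is stochastic, and generation is slow on CPU):
# passage = "Photosynthesis converts light energy into chemical energy in plants."
# print(generate(input=passage, temperature=0.7, top_p=0.9, top_k=40,
#                max_new_tokens=256))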
import gradio as gr

example = """Education is the process of facilitating learning, or the acquisition of knowledge, skills, values, morals, beliefs, habits,
and personal development. There are many types of potential educational aims and objectives,
irrespective of the specific subject being learned. Some can cross multiple school disciplines.
"""

# The four inputs map positionally onto generate(input, temperature, top_p, top_k).
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(value=example, label="inputs"),
        gr.Slider(0, 1, value=0.75, step=0.05, label="temperature"),
        gr.Slider(0, 1, value=0.95, step=0.05, label="top_p"),
        gr.Slider(0, 100, value=50, step=10, label="top_k"),
    ],
    outputs=["text"],
)
demo.launch(inline=False)
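# Generation can take minutes on CPU, so on Gradio 3.x request queuing could be
# enabled before launching (an optional tweak, not part of the original app):
# demo.queue().launch(inline=False)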