import sys
from typing import Union

import torch
import gradio as gr
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

base_model = "nickypro/tinyllama-15M"


class Prompter(object):
    def generate_prompt(
        self,
        instruction: str,
        label: Union[None, str] = None,
    ) -> str:
        res = f"{instruction}\nAnswer: "
        if label:
            res = f"{res}{label}"
        return res

    def get_response(self, output: str) -> str:
        # Keep only the model's answer and map ASCII operators to the
        # division/multiplication signs for display.
        return output.split("Answer:")[1].strip().replace("/", "\u00F7").replace("*", "\u00D7")


model = LlamaForCausalLM.from_pretrained(
    base_model,
    torch_dtype=torch.float32,
    device_map="auto",
)
# Load the finetuned LoRA adapter weights on top of the base model.
model = PeftModel.from_pretrained(
    model,
    "checkpoint-16000",
    torch_dtype=torch.float32,
)
model.eval()
# torch.compile requires PyTorch 2.x and is unsupported on Windows.
if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)

tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token_id = 0  # use <unk> (id 0) as the pad token
tokenizer.padding_side = "left"


def generate_answers(instructions, model, tokenizer):
    prompter = Prompter()
    raw_answers = []
    for instruction in instructions:
        prompt = prompter.generate_prompt(instruction)
        inputs = tokenizer(prompt, return_tensors="pt")
        # Move inputs to the model's device (device_map="auto" may place it on GPU).
        input_ids = inputs["input_ids"].to(model.device)
        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                return_dict_in_generate=True,
                output_scores=True,
                pad_token_id=0,
                eos_token_id=tokenizer.eos_token_id,
                max_new_tokens=16,
            )
        s = generation_output.sequences[0]
        raw_answers.append(tokenizer.decode(s, skip_special_tokens=True).strip())
    return raw_answers


def evaluate(instruction):
    return generate_answers([instruction], model, tokenizer)[0]


if __name__ == "__main__":
    gr.Interface(
        fn=evaluate,
        inputs=[
            gr.components.Textbox(
                lines=1,
                label="Arithmetic",
                placeholder="63303235 + 20239503",
            )
        ],
        outputs=[
            gr.Textbox(
                lines=5,
                label="Output",
            )
        ],
        title="Arithmetic LLaMA",
        description="A 15M-parameter LLaMA model finetuned on a+b arithmetic tasks.",
    ).queue().launch(server_name="0.0.0.0", share=True)
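
# A minimal sketch of querying the launched app programmatically, shown here as
# a comment since launch() blocks. It assumes the default Gradio port (7860)
# and the auto-generated "/predict" endpoint name for a single-function
# gr.Interface; both are assumptions, not values set by this script.
#
#     from gradio_client import Client
#
#     client = Client("http://0.0.0.0:7860/")
#     # Sends one instruction string and prints the raw decoded answer.
#     print(client.predict("63303235 + 20239503", api_name="/predict"))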