# nyc-savvy-llama2-7b / pefttester.py (Nick Doiron)
## Taken from the QLoRA Guanaco demo on Gradio
# https://github.com/artidoro/qlora
# https://colab.research.google.com/drive/17XEqL1JcmVWjHkT-WczdYkJlNINacwG7?usp=sharing
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
# Local path to the fine-tuned checkpoint
model_name = './nyc-savvy'
m = AutoModelForCausalLM.from_pretrained(model_name)
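# Two hedged sketches, neither active in this script:
# 1) If './nyc-savvy' held only a LoRA/QLoRA adapter (as the name pefttester
#    suggests) rather than merged weights, it could be attached to a base model
#    with peft -- the base model id below is an assumption:
#
#       from peft import PeftModel
#       base = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
#       m = PeftModel.from_pretrained(base, model_name)
#
# 2) To fit the 7B model on a single GPU, the QLoRA demo credited above loads
#    the model in 4-bit; the same pattern could be used here:
#
#       from transformers import BitsAndBytesConfig
#       m = AutoModelForCausalLM.from_pretrained(
#           model_name,
#           device_map='auto',
#           quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#       )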
# The checkpoint is LLaMA-based, so prefer the LLaMA tokenizer for it
if 'llama' in model_name or 'savvy' in model_name:
    tok = LlamaTokenizer.from_pretrained(model_name)
else:
    tok = AutoTokenizer.from_pretrained(model_name)
tok.bos_token_id = 1  # LLaMA's beginning-of-sequence token id
stop_token_ids = [0]  # id 0 (<unk> in the LLaMA vocab), as in the demo
class StopOnTokens(StoppingCriteria):
    # Halt generation as soon as the most recent token is a stop token
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
stop = StopOnTokens()
max_new_tokens = 1536
messages = "A chat between a curious human and an assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
messages += "### Human: What museums should I visit? - My kids are aged 12 and 5"
messages += "### Assistant: "
input_ids = tok(messages, return_tensors="pt").input_ids
input_ids = input_ids.to(m.device)
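# Sampling settings from the demo; top_k=0 leaves top-k filtering disabled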
temperature = 0.7
top_p = 0.9
top_k = 0
repetition_penalty = 1.1
op = m.generate(
    input_ids=input_ids,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    do_sample=temperature > 0.0,
    top_p=top_p,
    top_k=top_k,
    repetition_penalty=repetition_penalty,
    stopping_criteria=StoppingCriteriaList([stop]),
)
# op has shape (batch, sequence); decode each generated sequence
for line in op:
    print(tok.decode(line))
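# Sketch, assuming the streaming pattern of the Gradio demo credited above
# (not active in this script): run generate() on a worker thread and print
# tokens as TextIteratorStreamer yields them.
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
#   Thread(target=m.generate, kwargs=dict(
#       input_ids=input_ids,
#       max_new_tokens=max_new_tokens,
#       streamer=streamer,
#       stopping_criteria=StoppingCriteriaList([stop]),
#   )).start()
#   for new_text in streamer:
#       print(new_text, end="", flush=True)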