## Taken from the QLoRA Guanaco demo on Gradio
# https://github.com/artidoro/qlora
# https://colab.research.google.com/drive/17XEqL1JcmVWjHkT-WczdYkJlNINacwG7?usp=sharing
import torch
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Load the fine-tuned model from the local checkpoint directory.
model_name = './nyc-savvy'
m = AutoModelForCausalLM.from_pretrained(model_name)

# LLaMA-based checkpoints need the LLaMA-specific tokenizer.
if 'llama' in model_name or 'savvy' in model_name:
    tok = LlamaTokenizer.from_pretrained(model_name)
else:
    tok = AutoTokenizer.from_pretrained(model_name)
tok.bos_token_id = 1  # LLaMA's beginning-of-sequence token id

# Halt generation as soon as the model emits any of these token ids.
stop_token_ids = [0]

class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

stop = StopOnTokens()

max_new_tokens = 100  # the Gradio demo used 1536

# Guanaco-style prompt: a system preamble followed by ### Human / ### Assistant turns.
messages = "A chat between a curious human and an assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
messages += "### Human: What museums should I visit? - My kids are aged 12 and 5\n"
messages += "### Assistant: "

# Tokenize the prompt and move it to the model's device.
input_ids = tok(messages, return_tensors="pt").input_ids
input_ids = input_ids.to(m.device)

# Sampling parameters (top_k=0 disables top-k filtering).
temperature = 0.7
top_p = 0.9
top_k = 0
repetition_penalty = 1.1

op = m.generate(
    input_ids=input_ids,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    do_sample=temperature > 0.0,
    top_p=top_p,
    top_k=top_k,
    repetition_penalty=repetition_penalty,
    stopping_criteria=StoppingCriteriaList([stop]),
)
for line in op:
    print(tok.decode(line))
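
# --- Optional: stream tokens as they are generated ---
# A minimal sketch, not part of the original demo script, showing how the
# TextIteratorStreamer imported above could be wired up: generate() runs in a
# background thread while the streamer yields decoded text chunks on the main
# thread. It reuses m, tok, input_ids, and the sampling parameters from above.
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
    input_ids=input_ids,
    streamer=streamer,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    do_sample=temperature > 0.0,
    top_p=top_p,
    top_k=top_k,
    repetition_penalty=repetition_penalty,
    stopping_criteria=StoppingCriteriaList([stop]),
)
thread = Thread(target=m.generate, kwargs=generation_kwargs)
thread.start()
for new_text in streamer:  # yields decoded text as soon as tokens arrive
    print(new_text, end="", flush=True)
thread.join()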