# NOTE: "Spaces: Sleeping" header lines were Hugging Face Spaces page residue
# from the scrape, not part of the program; kept here only as a comment.
## Adapted from the QLoRA Guanaco demo on Gradio
# https://github.com/artidoro/qlora
# https://colab.research.google.com/drive/17XEqL1JcmVWjHkT-WczdYkJlNINacwG7?usp=sharing
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    LlamaTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Model and tokenizer are loaded once at import time (module-level side effect);
# device_map="auto" lets accelerate place layers on available GPUs/CPU.
model_name = 'monsoon-nlp/nyc-savvy-llama2-7b'
m = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tok = LlamaTokenizer.from_pretrained(model_name)
# Force the LLaMA BOS token id — presumably the checkpoint's tokenizer config
# ships a wrong/missing value; TODO confirm against the upstream repo.
tok.bos_token_id = 1

# Token ids that end generation early via StopOnTokens below.
# NOTE(review): 0 is <unk> for LLaMA tokenizers — confirm this shouldn't be
# tok.eos_token_id instead.
stop_token_ids = [0]
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last emitted token is one of the
    module-level ``stop_token_ids``."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the most recent token of the first (and only) sequence matters.
        last_token = input_ids[0][-1]
        return any(last_token == stop_id for stop_id in stop_token_ids)
# Single shared criteria instance — stateless, so safe to reuse across requests.
stop = StopOnTokens()

# Upper bound on tokens generated per request (referenced by query()).
max_new_tokens = 1536
def query(prompt, chatStart):
    """Generate one assistant reply for a single question.

    Args:
        prompt: system-style pre-prompt prepended to the conversation.
        chatStart: the user's question (optionally "{Question}? - {Context}").

    Returns:
        The generated text following the final "### Assistant:" marker.

    Raises:
        ValueError: if the decoded output somehow lacks the marker
            (str.index raises on a missing substring).
    """
    # Guanaco-style turn format. The newline before "### Assistant:" keeps the
    # marker on its own turn instead of fusing into the question text.
    messages = prompt.strip() + "\n"
    messages += "### Human: " + chatStart.strip() + "\n"
    messages += "### Assistant: "

    input_ids = tok(messages, return_tensors="pt").input_ids.to(m.device)

    # Sampling hyperparameters; sampling is enabled only for positive temperature.
    temperature = 0.7
    top_p = 0.9
    top_k = 0
    repetition_penalty = 1.1
    op = m.generate(
        input_ids=input_ids,
        # Use the module-level cap (it was previously a dead constant while the
        # call hardcoded 100).
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    # op is a batch of output sequences; decode each and join them.
    outputtxt = "".join(tok.decode(line) + "\n" for line in op)
    # Skip the whole marker: the original "+ 4" only removed "### " and left
    # "Assistant: " at the start of every response.
    marker = '### Assistant:'
    return outputtxt[outputtxt.index(marker) + len(marker):]
# Default system prompt pre-filled in the first textbox.
defaultPrompt = "A chat between a curious human and an assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."

# Use the modern component API (gr.Textbox(value=...)): gr.inputs.* has been
# deprecated since Gradio 2 and removed in Gradio 4, and the output below
# already uses the new-style gr.Markdown component.
prompter = gr.Textbox(value=defaultPrompt, label="Pre-prompt")
questioner = gr.Textbox(
    value="What museums should I visit? - My kids are aged 12 and 5",
    label="{Question}? - {Context}",
)

iface = gr.Interface(
    fn=query,
    inputs=[prompter, questioner],
    outputs=[gr.Markdown(value="", label="Response")],
    title="LLaMa2-7b fine-tuned on the AskNYC subreddit",
    description="Dataset is 13k Q&A from 2015 - June 2019, fine-tuning done with QLoRA. Use the '{Question}? - {Context}' format for longer, less snarky results",
    allow_flagging="never",
)
iface.launch()