## Adapted from the QLoRA Guanaco demo on Gradio
# https://github.com/artidoro/qlora
# https://colab.research.google.com/drive/17XEqL1JcmVWjHkT-WczdYkJlNINacwG7?usp=sharing
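# This Space loads the nyc-savvy LLaMA-2-7B fine-tune and serves a single-turn
# question-answering demo through a Gradio Interface.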
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer, StoppingCriteria, StoppingCriteriaList
model_name = 'monsoon-nlp/nyc-savvy-llama2-7b'
m = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tok = LlamaTokenizer.from_pretrained(model_name)
tok.bos_token_id = 1  # pin the BOS token id explicitly, as the upstream QLoRA demo does
stop_token_ids = [0]  # generation halts once any of these token ids appears
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # stop as soon as the most recently generated token is a stop id
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
stop = StopOnTokens()
max_new_tokens = 1536  # generation budget, passed to model.generate in query()
def query(prompt, chatStart):
    # assemble a Guanaco-style prompt: pre-prompt, one human turn, then an open assistant turn
    messages = prompt.strip() + "\n"
    messages += "### Human: " + chatStart.strip() + "\n"
    messages += "### Assistant: "
    input_ids = tok(messages, return_tensors="pt").input_ids
    input_ids = input_ids.to(m.device)
    temperature = 0.7
    top_p = 0.9
    top_k = 0
    repetition_penalty = 1.1
    op = m.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        stopping_criteria=StoppingCriteriaList([stop]),
    )
    outputtxt = ""
    for line in op:
        outputtxt += tok.decode(line, skip_special_tokens=True) + "\n"
    # generate() echoes the prompt, so return only the text after the assistant marker
    marker = '### Assistant:'
    return outputtxt[outputtxt.index(marker) + len(marker):]
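
# A minimal sketch of calling query() directly, outside the Gradio UI; the
# question text reuses the demo default and is illustrative only. Kept
# commented out so the model does not generate at import time.
# print(query(
#     "A chat between a curious human and an assistant.",
#     "What museums should I visit? - My kids are aged 12 and 5",
# ))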
defaultPrompt = "A chat between a curious human and an assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
prompter = gr.Textbox(value=defaultPrompt, label="Pre-prompt")
questioner = gr.Textbox(value="What museums should I visit? - My kids are aged 12 and 5", label="{Question}? - {Context}")
iface = gr.Interface(
fn=query,
inputs=[
prompter,
questioner,
],
outputs=[
gr.Markdown(value="", label="Response"),
],
title="LLaMa2-7b fine-tuned on the AskNYC subreddit",
description="Trained on 13k Q&A pairs from the subreddit (2015 to June 2019); fine-tuned with QLoRA. Use the '{Question}? - {Context}' format for longer, less snarky results.",
allow_flagging="never",
)
iface.launch()