# app.py: Gradio demo for MosaicML's mpt-7b-instruct
import os
import sys
import warnings
from threading import Event, Thread
from typing import Any, Dict, Tuple

import einops  # needed by MPT's remote modeling code
import gradio as gr
import requests
import torch
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
INSTRUCTION_KEY = "### Instruction:"
RESPONSE_KEY = "### Response:"
END_KEY = "### End"
INTRO_BLURB = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
)
PROMPT_FOR_GENERATION_FORMAT = """{intro}
{instruction_key}
{instruction}
{response_key}
""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    instruction="{instruction}",
    response_key=RESPONSE_KEY,
)
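# For reference, PROMPT_FOR_GENERATION_FORMAT.format(instruction=...) renders as:
#
#   Below is an instruction that describes a task. Write a response that
#   appropriately completes the request.
#   ### Instruction:
#   <instruction text>
#   ### Response: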
# Local module that defines the InstructionTextGenerationPipeline used below
from InstructionTextGenerationPipeline import *
from timeit import default_timer as timer
import time
from datetime import datetime
import json
import os.path as osp
import pprint

pp = pprint.PrettyPrinter(indent=4)
LIBRARY_PATH = "/home/ec2-user/workspace/Notebooks/lib"
module_path = os.path.abspath(LIBRARY_PATH)
if module_path not in sys.path:
    sys.path.append(module_path)
print(f"sys.path : {sys.path}")

def complete(state="complete"):
    """Print a simple progress marker."""
    print(f"\nCell {state}")

complete(state="imports done")
complete(state="start generate")
name = 'mosaicml/mpt-7b-instruct'
config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'torch'
config.init_device = 'cuda:0' # For fast initialization directly on GPU!
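# (Assumption, per the MPT-7B model card: attn_impl may also be set to
# "triton" for faster attention when the triton kernels are installed;
# "torch" is the most portable option.)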
generate = InstructionTextGenerationPipeline(
    name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    config=config,
)
stop_token_ids = generate.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])
complete(state="model loaded")
# Custom stopping criterion: halt generation as soon as the most recent
# token is one of the stop token ids (here, <|endoftext|>)
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
def process_stream(instruction, temperature, top_p, top_k, max_new_tokens):
    # Tokenize the formatted instruction and move it to the model's device
    input_ids = generate.tokenizer(
        generate.format_instruction(instruction), return_tensors="pt"
    ).input_ids
    input_ids = input_ids.to(generate.model.device)

    # Initialize the streamer and stopping criteria
    streamer = TextIteratorStreamer(
        generate.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    stop = StopOnTokens()

    # Treat temperatures below 0.1 as greedy decoding
    if temperature < 0.1:
        temperature = 0.0
        do_sample = False
    else:
        do_sample = True

    gkw = {
        **generate.generate_kwargs,
        **{
            "input_ids": input_ids,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "do_sample": do_sample,
            "top_p": top_p,
            "top_k": top_k,
            "streamer": streamer,
            "stopping_criteria": StoppingCriteriaList([stop]),
        },
    }

    # Run generation on a background thread so we can consume the streamer here
    response = ""

    def generate_and_signal_complete():
        generate.model.generate(**gkw)

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()
    for new_text in streamer:
        response += new_text
    return response
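# Hypothetical direct call, bypassing the UI (values are illustrative only;
# assumes the pipeline above loaded successfully):
#   process_stream("Explain what MPT-7B is.", temperature=0.3, top_p=0.95,
#                  top_k=50, max_new_tokens=256)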
gr.close_all()
def tester(uPrompt, max_new_tokens, temperature, top_k, top_p):
    # Forward the UI inputs to the streaming generator (note the argument
    # order differs between this callback and process_stream)
    response = process_stream(uPrompt, temperature, top_p, top_k, max_new_tokens)
    return response
# Interactive controls for the demo
demo = gr.Interface(
    fn=tester,
    inputs=[
        gr.Textbox(label="Prompt", info="Prompt", lines=3, value="Provide Prompt"),
        gr.Slider(256, 3072, value=1024, step=256, label="Tokens"),
        gr.Slider(0.0, 1.0, value=0.1, step=0.1, label="temperature:"),
        gr.Slider(0, 1, value=0, step=1, label="top_k:"),
        gr.Slider(0.0, 1.0, value=0.05, step=0.05, label="top_p:"),
    ],
    outputs=["text"],
    title="Mosaic MPT-7B",
)
demo.launch(
    share=True,
    server_name="0.0.0.0",
    server_port=7860,
)
# Note on running behind SSL
# See: https://github.com/gradio-app/gradio/issues/563
# a = gr.Interface(lambda x: x, "image", "image", examples=["lion.jpg"]).launch(
#     share=False, ssl_keyfile="key.pem", ssl_certfile="cert.pem")
# It seems we need an appropriate non-self-signed cert that the customer
# will accept on their network.
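# A minimal sketch of an SSL launch for this app, assuming key.pem / cert.pem
# exist and are signed by a CA the client trusts (paths are illustrative):
#
#   demo.launch(
#       share=False,
#       server_name="0.0.0.0",
#       server_port=7860,
#       ssl_keyfile="key.pem",
#       ssl_certfile="cert.pem",
#   )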