# Phi-3-Medium / app.py
import gradio as gr
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
import torch
import spaces
import os
import subprocess
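# flash-attn is installed at runtime below rather than via requirements.txt.
# Setting FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is meant to skip compiling the
# CUDA extension locally, so pip can fall back to a prebuilt wheel; this is a
# common pattern on Hugging Face Spaces, where build tooling is limited.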
# Install flash attention (keep the current environment so pip stays on PATH).
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
theme = gr.themes.Base(
font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
)
# Phi-3-medium with 4k context, kept under its own names (model1/tokenizer1)
# so the 128k-context load further down does not replace it.
model_name1 = "microsoft/Phi-3-medium-4k-instruct"
from transformers import AutoModelForCausalLM, AutoTokenizer
model1 = AutoModelForCausalLM.from_pretrained(model_name1, device_map='cuda', torch_dtype=torch.float16, _attn_implementation="flash_attention_2", trust_remote_code=True)
tokenizer1 = AutoTokenizer.from_pretrained(model_name1)
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the most recent token is one of the stop ids."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [29, 0]  # token ids treated as end-of-turn markers
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
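# The stop ids above are hard-coded; whether 29 and 0 actually correspond to this
# tokenizer's end-of-turn tokens is an assumption inherited from the original code.
# A sketch of deriving them from the tokenizer instead (not used by the app as written):
# stop_ids = [tokenizer1.convert_tokens_to_ids("<|end|>"), tokenizer1.eos_token_id]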
@spaces.GPU(queue=False)
def predict1(message, history, temperature1, max_tokens1, repetition_penalty1, top_p1):
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()
    # Flatten the chat history into Phi-3's <|user|>/<|assistant|>/<|end|> format.
    messages = "".join(["".join(["\n<|end|>\n<|user|>\n"+item[0], "\n<|end|>\n<|assistant|>\n"+item[1]]) for item in history_transformer_format])
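    # The prompt is assembled by hand above. For reference, a sketch of the same
    # thing via the tokenizer's chat template (assumes the Phi-3 tokenizer ships
    # one; not wired into this app):
    # chat = []
    # for user_msg, assistant_msg in history_transformer_format:
    #     chat.append({"role": "user", "content": user_msg})
    #     if assistant_msg:
    #         chat.append({"role": "assistant", "content": assistant_msg})
    # prompt = tokenizer1.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)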
    model_inputs = tokenizer1([messages], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer1, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_tokens1,
        do_sample=True,
        top_p=top_p1,
        repetition_penalty=repetition_penalty1,
        temperature=temperature1,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    # Run generation in a background thread and stream tokens back to the UI.
    t = Thread(target=model1.generate, kwargs=generate_kwargs)
    t.start()
    partial_message = ""
    for new_token in streamer:
        if new_token != '<':  # skip bare '<' chunks left over from special tokens
            partial_message += new_token
            yield partial_message
# Phi-3-medium with 128k context; this one keeps the plain model/tokenizer names.
# The StopOnTokens class defined above is reused as-is.
model_name = "microsoft/Phi-3-medium-128k-instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cuda', torch_dtype=torch.float16, _attn_implementation="flash_attention_2", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)
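# Note: keeping both medium models resident in float16 is memory-hungry; at roughly
# 14B parameters each, that is on the order of 28 GB of GPU memory per model, so this
# layout assumes a large-memory GPU (for example an 80 GB card behind spaces.GPU).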
@spaces.GPU(queue=False)
def predict(message, history, temperature, max_tokens, repetition_penalty, top_p):
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()
    # Same prompt flattening as predict1, but against the 128k-context model.
    messages = "".join(["".join(["\n<|end|>\n<|user|>\n"+item[0], "\n<|end|>\n<|assistant|>\n"+item[1]]) for item in history_transformer_format])
    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    partial_message = ""
    for new_token in streamer:
        if new_token != '<':  # skip bare '<' chunks left over from special tokens
            partial_message += new_token
            yield partial_message
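# ---- Gradio UI ----
# Two parallel sets of controls are defined below (suffix "1" for the 4k model,
# no suffix for the 128k model) because each gr.ChatInterface needs its own
# slider and chatbot components.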
max_tokens1 = gr.Slider(
minimum=512,
maximum=4096,
value=4000,
step=32,
interactive=True,
label="Maximum number of new tokens to generate",
)
repetition_penalty1 = gr.Slider(
minimum=0.01,
maximum=5.0,
value=1,
step=0.01,
interactive=True,
label="Repetition penalty",
)
temperature1 = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.7,
step=0.05,
visible=True,
interactive=True,
label="Temperature",
)
top_p1 = gr.Slider(
minimum=0.01,
maximum=0.99,
value=0.9,
step=0.01,
visible=True,
interactive=True,
label="Top P",
)
chatbot1 = gr.Chatbot(
label="Phi3-medium-4k",
show_copy_button=True,
likeable=True,
layout="panel"
)
output = gr.Textbox(label="Prompt")

# Tab contents for the 4k-context model.
with gr.Blocks() as demo_4k:
    gr.ChatInterface(
        fn=predict1,
        chatbot=chatbot1,
        additional_inputs=[
            temperature1,
            max_tokens1,
            repetition_penalty1,
            top_p1,
        ],
    )
max_tokens = gr.Slider(
minimum=64000,
maximum=128000,
value=100000,
step=1000,
interactive=True,
label="Maximum number of new tokens to generate",
)
repetition_penalty = gr.Slider(
minimum=0.01,
maximum=5.0,
value=1,
step=0.01,
interactive=True,
label="Repetition penalty",
)
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.7,
step=0.05,
visible=True,
interactive=True,
label="Temperature",
)
top_p = gr.Slider(
minimum=0.01,
maximum=0.99,
value=0.9,
step=0.01,
visible=True,
interactive=True,
label="Top P",
)
chatbot = gr.Chatbot(
label="Phi3-medium-128k",
show_copy_button=True,
likeable=True,
layout="panel"
)
output = gr.Textbox(label="Prompt")

# Tab contents for the 128k-context model.
with gr.Blocks() as demo_128k:
    gr.ChatInterface(
        fn=predict,
        chatbot=chatbot,
        additional_inputs=[
            temperature,
            max_tokens,
            repetition_penalty,
            top_p,
        ],
    )
# Top-level app: both model UIs in a tabbed layout.
with gr.Blocks(title="Phi 3 Medium DEMO", theme=theme) as demo:
    gr.Markdown("# Phi3 Medium all in one")
    gr.TabbedInterface([demo_128k, demo_4k], ['Phi3 medium 128k', 'Phi3 medium 4k'])

demo.launch(share=True)
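# Assumed usage: running this file directly (python app.py) starts the Gradio app;
# on a Hugging Face Space the @spaces.GPU decorator is meant to request a GPU for
# each decorated call.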