# ID2223 / app.py
# Last commit by EPark25: 1729e8b ("Add placeholders")
import gradio as gr
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from transformers import TextIteratorStreamer
import chess
import chess.svg
import time
from PIL import Image
import io
import cairosvg
from threading import Thread
import torch
"""
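Gradio demo: chat with the chess model loaded below, ask it to guess the Elo of a game given as move text, and replay that game on a board. The model runs locally in this process via `model.generate`; for serving a model through the hosted `huggingface_hub` Inference API instead, see the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference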
"""
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# LLM
# Base model id. NOTE(review): not referenced anywhere else in this file as
# shown; the weights below are loaded from `peft_model_id`. Presumably kept to
# record which model the fine-tune is based on. Confirm, or remove.
model_id = "unsloth/Llama-3.2-3B-Instruct"
# Hub repo whose weights are actually loaded below (presumably the chess
# fine-tune / LoRA adapter of `model_id`; confirm).
peft_model_id = "EPark25/llama_chess"
# Alternative checkpoint, currently disabled:
# peft_model_id = "samlama111/lora_model"
# Load model and tokenizer through Unsloth with 4-bit quantized weights and a
# 2048-token maximum sequence length. dtype=None leaves the dtype choice to
# Unsloth (described as auto-detection in its docs; confirm).
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=peft_model_id,
max_seq_length=2048,
dtype=None,
load_in_4bit=True,
)
# NOTE(review): some transformers/bitsandbytes versions reject `.to()` on
# 4-bit quantized models, and bitsandbytes 4-bit weights presumably need CUDA.
# The "cpu" branch of `device` may therefore not work here. Confirm.
model = model.to(device)
# Install the Llama 3.1 chat template. The mapping renames the message keys and
# role names the template reads (ShareGPT style). Conversations passed to
# `tokenizer.apply_chat_template` are therefore presumably expected to use
# "from"/"value" keys and "human"/"gpt" roles, not "role"/"content"/"user"/
# "assistant". Confirm against Unsloth's `get_chat_template` documentation.
tokenizer = get_chat_template(
tokenizer,
chat_template="llama-3.1",
mapping={
"role": "from",
"content": "value",
"user": "human",
"assistant": "gpt",
}, # ShareGPT style
)
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
# Module-level streamer that receives decoded text from `model.generate` as it
# is produced. skip_prompt drops the echoed prompt; skip_special_tokens drops
# special tokens.
# NOTE(review): a TextIteratorStreamer keeps per-generation state (queue, token
# cache). A single shared instance is unsafe if two generations ever overlap;
# prefer creating one per generate() call.
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=True, skip_special_tokens=True
) # for removing the beginning such as <assistant>... and <eos> tokens
with gr.Blocks() as demo:

    def display_board(board):
        """Render ``board`` (a ``chess.Board``) as a PIL image for ``gr.Image``.

        python-chess draws the position as SVG. cairosvg rasterises the SVG to
        PNG bytes, which PIL then opens.
        """
        svg_data = chess.svg.board(board)
        png_data = cairosvg.svg2png(bytestring=svg_data)
        return Image.open(io.BytesIO(png_data))

    # --- UI components ---------------------------------------------------
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox(
        submit_btn=True, label="Chat", show_label=True, placeholder="Who are you?"
    )
    chess_moves = gr.Textbox(
        submit_btn=True,
        label="Chess",
        show_label=True,
        placeholder="1. e4 c5 2. Nf3 e6 3. d4 d5 4. exd5 exd5 5. Ne5 a6 6. Qh5 Nf6 7. Qxf7# 1-0",
    )
    chess_board = gr.Image(display_board(chess.Board()))
    clear = gr.Button("Clear")

    # PGN game-termination markers. They are not moves and must never reach
    # ``push_san``.
    _GAME_RESULTS = frozenset({"1-0", "0-1", "1/2-1/2", "*"})

    def show_moves(user_message):
        """Replay the game in ``user_message`` on a fresh board, one frame per move.

        ``user_message`` is move text such as ``"1. e4 c5 2. Nf3 Nc6 1-0"``.
        Move numbers are stripped, whether standalone (``"1."``, ``"1..."``)
        or glued to the move (``"1.e4"``). A trailing result marker is skipped
        only when one is present.

        Yields:
            A PIL image of the starting position, then one image after each move.

        Raises:
            gr.Error: if a token is not a legal SAN move in the current position.
        """
        board = chess.Board()
        # Always show the starting position first. Empty or result-only input
        # then still resets the display; a generator's ``return`` value is
        # never shown by Gradio.
        yield display_board(board)
        # Keep only what follows the last "." of each token ("1." -> "",
        # "1.e4" -> "e4", "Nf3" -> "Nf3"); SAN itself never contains ".".
        tokens = [token.rsplit(".", 1)[-1] for token in user_message.split()]
        moves_made = [token for token in tokens if token]
        # Drop the final token only if it is a result marker, so the last move
        # of a game written without a result is still played.
        if moves_made and moves_made[-1] in _GAME_RESULTS:
            moves_made.pop()
        for move_made in moves_made:
            try:
                board.push_san(move_made)
            except ValueError as err:
                # python-chess SAN parsing/legality errors subclass ValueError.
                raise gr.Error(f"Cannot play move {move_made!r}: {err}") from err
            time.sleep(0.2)  # pace the animation so each move is visible
            yield display_board(board)

    def chat_user(user_message, history: list):
        """Record the chat-box text as a user turn and clear the box.

        Returns ``("", new_history)``. ``new_history`` is a new list, so the
        list Gradio passed in is not mutated.
        """
        return "", history + [{"role": "user", "content": user_message}]

    def chess_user(user_message, history: list):
        """Record a request to guess the Elo of the submitted game.

        The move text is returned unchanged so it stays in the ``chess_moves``
        box, which the chained ``show_moves`` step reads afterwards.
        """
        return user_message, history + [
            {"role": "user", "content": f"Guess the elo of this game: {user_message}"}
        ]

    # The tokenizer's chat template was built with a ShareGPT-style mapping
    # (the ``mapping=`` argument of ``get_chat_template``). It therefore reads
    # "from"/"value" keys and "human"/"gpt" roles, not Gradio's
    # "role"/"content" keys and "user"/"assistant" roles.
    _SHAREGPT_ROLES = {"user": "human", "assistant": "gpt"}

    def _to_sharegpt(history):
        """Convert Gradio ``type="messages"`` history to the template's ShareGPT keys."""
        return [
            {
                "from": _SHAREGPT_ROLES.get(message["role"], message["role"]),
                "value": message["content"],
            }
            for message in history
        ]

    def chatbot_response(history: list):
        """Stream the model's reply to ``history`` into a new assistant message.

        ``model.generate`` runs on a background thread and feeds a per-call
        ``TextIteratorStreamer``. Each decoded chunk is appended to the last
        message, and the whole history is yielded so the Chatbot updates live.

        Raises:
            gr.Error: if ``model.generate`` fails.
        """
        prompt = tokenizer.apply_chat_template(
            _to_sharegpt(history),
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = tokenizer([prompt], return_tensors="pt").to(device)
        # One streamer per call. A streamer carries per-generation state (queue,
        # token cache), so a shared one would interleave output whenever two
        # generations overlap (e.g. the chat and chess events fire together).
        # skip_prompt/skip_special_tokens remove the echoed prompt and the
        # header/<eos> tokens from the streamed text.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(inputs, do_sample=True, streamer=streamer)
        errors = []

        def _generate():
            try:
                model.generate(**generate_kwargs)
            except Exception as err:  # thread boundary: reported to the UI below
                errors.append(err)
                # generate() never sent the stop signal. Send it so the
                # consumer loop below terminates instead of blocking forever.
                streamer.end()

        worker = Thread(target=_generate)
        worker.start()
        history.append({"role": "assistant", "content": ""})
        for text in streamer:
            history[-1]["content"] += text
            yield history
        worker.join()
        if errors:
            raise gr.Error(f"Generation failed: {errors[0]}")

    # --- Event wiring ----------------------------------------------------
    # Chat: record the user turn (clearing the box), then stream the reply.
    msg.submit(chat_user, [msg, chatbot], [msg, chatbot]).then(
        chatbot_response, chatbot, chatbot
    )
    # Chess: ask for an Elo guess, stream the reply, then replay the game.
    chess_moves.submit(
        chess_user, [chess_moves, chatbot], [chess_moves, chatbot]
    ).then(chatbot_response, chatbot, chatbot).then(
        show_moves, chess_moves, chess_board
    )
    # Setting the Chatbot's value to None clears the conversation.
    clear.click(lambda: None, None, chatbot)
if __name__ == "__main__":
    # Start the Gradio server for the UI defined in `demo`. The call must be
    # indented under the guard; unindented, it is an IndentationError.
    demo.launch()