|
import gradio as gr |
|
from unsloth import FastLanguageModel |
|
from unsloth.chat_templates import get_chat_template |
|
from transformers import TextIteratorStreamer |
|
import chess |
|
import chess.svg |
|
import time |
|
from PIL import Image |
|
import io |
|
import cairosvg |
|
from threading import Thread |
|
import torch |
|
|
|
""" |
|
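Gradio app that chats with the EPark25/llama_chess model (loaded locally with
Unsloth) and asks it to guess the Elo of a chess game from its move list.

This app does not use the Hugging Face Inference API. For more information on
`huggingface_hub` Inference API support, please check the docs:
https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference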
|
""" |
|
|
|
# Pick the torch device string used to place model inputs.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Base checkpoint name; not referenced elsewhere in this file.
model_id = "unsloth/Llama-3.2-3B-Instruct"
# Checkpoint actually loaded below (presumably a PEFT/LoRA adapter repo — confirm).
peft_model_id = "EPark25/llama_chess"

# Load the model and its tokenizer through Unsloth, quantized to 4 bits.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=peft_model_id,
    max_seq_length=2048,
    dtype=None,  # NOTE(review): assumed to mean "auto-select" in Unsloth — confirm
    load_in_4bit=True,
)

# NOTE(review): the model was loaded in 4-bit; confirm this move is needed/supported.
model = model.to(device)

# Install the Llama 3.1 chat template. The mapping presumably makes the template
# read ShareGPT-style messages ("from"/"value" keys, "human"/"gpt" roles) —
# confirm against the Unsloth docs.
tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3.1",
    mapping=dict(role="from", content="value", user="human", assistant="gpt"),
)

# Switch the Unsloth model into inference mode.
FastLanguageModel.for_inference(model)

# Single module-level streamer that yields decoded reply text only.
# NOTE(review): one TextIteratorStreamer is one queue; sharing it across
# concurrent generations would interleave their output.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
|
|
|
|
with gr.Blocks() as demo:

    def display_board(board):
        """Render a python-chess board as a PIL image.

        The board is drawn to SVG with ``chess.svg.board``, rasterized to PNG
        bytes with cairosvg, and opened with Pillow.
        """
        svg_data = chess.svg.board(board)
        png_data = cairosvg.svg2png(bytestring=svg_data)
        return Image.open(io.BytesIO(png_data))

    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox(
        submit_btn=True, label="Chat", show_label=True, placeholder="Who are you?"
    )
    chess_moves = gr.Textbox(
        submit_btn=True,
        label="Chess",
        show_label=True,
        placeholder="1. e4 c5 2. Nf3 e6 3. d4 d5 4. exd5 exd5 5. Ne5 a6 6. Qh5 Nf6 7. Qxf7# 1-0",
    )
    # Initially shows a freshly constructed chess.Board().
    chess_board = gr.Image(display_board(chess.Board()))
    clear = gr.Button("Clear")

    def show_moves(user_message):
        """Animate a move-list string on the board, yielding one image per frame.

        The first frame is the starting position, so a previous game's image
        is always replaced. Move-number tokens ("1.", "1...") and game-result
        tokens ("1-0", "0-1", "1/2-1/2", "*") are skipped; a number glued to
        its move ("1.e4") is accepted. Replay stops, with a UI warning, at the
        first token python-chess cannot play as a legal SAN move.
        """
        results = {"1-0", "0-1", "1/2-1/2", "*"}
        board = chess.Board()
        yield display_board(board)
        for token in (user_message or "").split():
            # Keep only the text after the last "." to strip move numbers.
            san = token.rsplit(".", 1)[-1]
            if not san or san in results:
                continue
            try:
                board.push_san(san)
            except ValueError as exc:
                # python-chess SAN parse/legality errors subclass ValueError.
                gr.Warning(f"Stopped replay at {san!r}: {exc}")
                break
            time.sleep(0.2)  # pacing between animation frames
            yield display_board(board)

    def chat_user(user_message, history: list):
        """Append a chat message to the history and clear the Chat textbox."""
        return "", history + [{"role": "user", "content": user_message}]

    def chess_user(user_message, history: list):
        """Append an Elo-guessing prompt built from the move list to the history.

        The move text is returned unchanged so the Chess textbox keeps it for
        ``show_moves``, which reads it later in the same event chain.
        """
        return user_message, history + [
            {"role": "user", "content": f"Guess the elo of this game: {user_message}"}
        ]

    def chatbot_response(history: list):
        """Stream the model's reply to ``history``, yielding the growing history.

        An empty assistant message is appended and filled in as text arrives
        from ``model.generate`` running in a background thread.
        """
        # Upper bound on generated tokens. Set explicitly so the limit does not
        # fall back to the checkpoint's max_length (20 total tokens by default
        # in transformers when unset).
        max_new_tokens = 512

        # NOTE(review): get_chat_template was configured with a ShareGPT-style
        # mapping ("from"/"value", "human"/"gpt"), but history uses Gradio's
        # "role"/"content" keys with "user"/"assistant" roles — confirm the
        # rendered prompt matches the fine-tuning format.
        prompt = tokenizer.apply_chat_template(
            history,
            tokenize=False,
            add_generation_prompt=True,
        )
        # The rendered template already contains its special tokens (e.g. BOS);
        # adding them again would duplicate them.
        inputs = tokenizer(
            [prompt], return_tensors="pt", add_special_tokens=False
        ).to(device)

        # One streamer per request: a shared streamer is a single queue, so two
        # overlapping generations (chat + chess events) would interleave text.
        reply_streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            inputs,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            streamer=reply_streamer,
        )

        def _generate():
            """Run generation; on failure, end the stream so the reader stops."""
            try:
                model.generate(**generate_kwargs)
            except Exception:
                # Without an end-of-stream signal the loop below blocks forever.
                reply_streamer.end()
                raise

        Thread(target=_generate).start()

        history.append({"role": "assistant", "content": ""})
        for text in reply_streamer:
            history[-1]["content"] += text
            yield history

    # Chat: record the user's message, then stream the model's reply.
    msg.submit(chat_user, [msg, chatbot], [msg, chatbot]).then(
        chatbot_response, chatbot, chatbot
    )

    # Chess: record the Elo prompt, stream the reply, then animate the game.
    chess_moves.submit(chess_user, [chess_moves, chatbot], [chess_moves, chatbot]).then(
        chatbot_response, chatbot, chatbot
    ).then(show_moves, chess_moves, chess_board)

    # Reset the conversation by setting the Chatbot value to None.
    clear.click(lambda: None, None, chatbot)


if __name__ == "__main__":
    demo.launch()
|
|