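"""Streamlit Tic Tac Toe app: the human plays X, and the O moves are generated by
Mistral-7B-Instruct through the Hugging Face Inference API, wrapped as a custom
LangChain LLM."""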
import streamlit as st
import numpy as np
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from pydantic.v1 import BaseModel, Field
from typing import Any, Optional, Dict, List
from huggingface_hub import InferenceClient
from langchain.llms.base import LLM
from markup import app_intro
import os

# Hugging Face Inference API configuration.
hf_token = os.getenv("apiToken")
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
# Generation settings: keep the reply short (a move such as "O: B2" needs only a few tokens).
kwargs = {"max_new_tokens": 10, "temperature": 0.1, "top_p": 0.95, "repetition_penalty": 1.0, "do_sample": True}


class KwArgsModel(BaseModel):
    """Mixin holding arbitrary text-generation keyword arguments."""
    kwargs: Dict[str, Any] = Field(default_factory=dict)


class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain LLM wrapper around huggingface_hub.InferenceClient."""
    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            kwargs=kwargs or {},
            inference_client=inference_client,
        )

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Stream the completion from the Inference API and join the chunks into one string.
        response_gen = self.inference_client.text_generation(
            prompt, **self.kwargs, stream=True, return_full_text=False
        )
        return "".join(response_gen)

    @property
    def _llm_type(self) -> str:
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        return {"model_name": self.model_name}
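
# Illustrative usage of the wrapper on its own (a sketch, assuming a valid token is
# available in the apiToken environment variable), e.g. for a quick smoke test:
#   llm = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs)
#   print(llm("Say hello"))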


def check_winner(board):
    """Return "X" or "O" if a full row, column, or diagonal is claimed, else None."""
    for row in board:
        if len(set(row)) == 1 and row[0] != "":
            return row[0]
    for col in board.T:
        if len(set(col)) == 1 and col[0] != "":
            return col[0]
    if len(set(board.diagonal())) == 1 and board[0, 0] != "":
        return board[0, 0]
    if len(set(np.fliplr(board).diagonal())) == 1 and board[0, 2] != "":
        return board[0, 2]
    return None


def check_draw(board):
    """Return True when the board has no empty cells left."""
    return not np.any(board == "")


def main():
    st.set_page_config(page_title="Tic Tac Toe", page_icon=":memo:", layout="wide")

    col1, col2 = st.columns([1, 2])
    with col1:
        st.image("image.jpg", use_column_width=True)
    with col2:
        st.markdown(app_intro(), unsafe_allow_html=True)

    st.markdown("____")

    # Load the game state from the session (or fall back to a fresh game).
    scores = st.session_state.get("scores", {"X": 0, "O": 0})
    board = np.array(st.session_state.get("board", [["" for _ in range(3)] for _ in range(3)]))
    current_player = st.session_state.get("current_player", "X")
    winner = check_winner(board)

    if winner is not None:
        # Record the win only once per finished game so reruns do not double-count it.
        if not st.session_state.get("winner_recorded", False):
            scores[winner] += 1
            st.session_state.scores = scores
            st.session_state.winner_recorded = True
        st.write(f"Player {winner} wins! Score: X - {scores['X']} | O - {scores['O']}")
    elif check_draw(board):
        st.write("Draw!")
    else:
        # Draw the 3x3 grid: empty cells are buttons while it is the human's (X) turn.
        for row in range(3):
            cols = st.columns(3)
            for col in range(3):
                button_key = f"button_{row}_{col}"
                if board[row, col] == "" and current_player == "X":
                    if cols[col].button(" ", key=button_key):
                        board[row, col] = current_player
                        st.session_state.board = board
                        st.session_state.current_player = "O"
                        progress = st.session_state.get("progress", [])
                        progress.append(f"{current_player}: {chr(65 + row)}{col + 1}")
                        st.session_state.progress = progress
                        st.experimental_rerun()
                else:
                    cols[col].write(board[row, col])

    # AI turn: ask the model for O's move, expected in the prompted "O: B2" format.
    if current_player == "O" and winner is None:
        with st.spinner("Calculating AI Move..."):
            ai_progress = ", ".join(st.session_state.progress)
            ai_move = get_ai_move(ai_progress)
            # Parse the cell reference (e.g. "B2") into 0-based row/column indices.
            ai_cell = ai_move.split(":")[-1].strip()
            ai_row = ord(ai_cell[0].upper()) - 65
            ai_col = int(ai_cell[1]) - 1
            board[ai_row, ai_col] = "O"
            st.session_state.board = board

            progress = st.session_state.get("progress", [])
            progress.append(f"O: {chr(65 + ai_row)}{ai_col + 1}")
            st.session_state.progress = progress

            st.session_state.current_player = "X"
            st.experimental_rerun()

    st.markdown("____")
    if st.button("Reset game"):
        st.session_state.board = [["" for _ in range(3)] for _ in range(3)]
        st.session_state.current_player = "X"
        st.session_state.progress = []
        st.session_state.winner_recorded = False
        st.experimental_rerun()

    if st.button("Reset Scores"):
        scores["X"] = 0
        scores["O"] = 0
        st.session_state.scores = scores
        st.experimental_rerun()

    progress = st.session_state.get("progress", [])
    st.write(", ".join(progress))

    st.write(f"Score: X - {scores['X']} | O - {scores['O']}")


def get_ai_move(progress):
    print("progress", progress)
    llm = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs)

    # Mistral instruct-style prompt; the moves played so far are injected as {progress}.
    template = """<s>[INST] Decide the next O move in tic tac toe game[/INST]
Example output format = O: B2

{progress}

Next move:"""

    prompt = PromptTemplate(template=template, input_variables=["progress"])
    llm_chain = LLMChain(prompt=prompt, llm=llm)

    answer = llm_chain.run(progress)
    answer = answer.replace("</s>", "").strip()

    print("ai move", answer)
    return answer


if __name__ == "__main__":
    main()