ethanrom commited on
Commit
0ea5b31
1 Parent(s): 595877e
Files changed (4) hide show
  1. app.py +165 -0
  2. image.jpg +0 -0
  3. markup.py +5 -0
  4. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ from langchain.chains import LLMChain
4
+ from langchain.prompts import PromptTemplate
5
+ from pydantic import BaseModel, Field
6
+ from typing import Any, Optional, Dict, List
7
+ from huggingface_hub import InferenceClient
8
+ from langchain.llms.base import LLM
9
+ from markup import app_intro
10
+ import os
11
+
12
# Hugging Face Inference API configuration.
HF_token = os.getenv("apiToken")  # secret token supplied via the Space's environment

model_name = "mistralai/Mistral-7B-Instruct-v0.1"
hf_token = HF_token
# Generation parameters: short, low-temperature completions — a single
# tic-tac-toe move (e.g. "O: B2") fits comfortably in 10 new tokens.
kwargs = {
    "max_new_tokens": 10,
    "temperature": 0.1,
    "top_p": 0.95,
    "repetition_penalty": 1.0,
    "do_sample": True,
}
17
+
18
class KwArgsModel(BaseModel):
    """Pydantic mixin that carries free-form generation keyword arguments."""

    # Arbitrary generation parameters (temperature, max_new_tokens, ...)
    # forwarded verbatim to the inference call; defaults to an empty dict.
    kwargs: Dict[str, Any] = Field(default_factory=dict)
20
+
21
class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain-compatible LLM that proxies a Hugging Face InferenceClient.

    Prompts are sent to the hosted text-generation endpoint for
    ``model_name``; streamed tokens are joined into one string result.
    """

    model_name: str                    # HF model repo id, e.g. "mistralai/Mistral-7B-Instruct-v0.1"
    inference_client: InferenceClient  # authenticated client used for every generation call

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        """Create the underlying client and initialise the pydantic model.

        Args:
            model_name: Hugging Face model id to query.
            hf_token: API token used to authenticate against the Inference API.
            kwargs: optional generation parameters (temperature, max_new_tokens, ...).
        """
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            # NOTE(review): hf_token is not a declared field on this model;
            # this relies on the base classes accepting/ignoring extras — confirm.
            hf_token=hf_token,
            # Bug fix: the original forwarded None when the caller omitted
            # kwargs, which fails validation of the Dict-typed field.
            kwargs=kwargs if kwargs is not None else {},
            inference_client=inference_client
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None
    ) -> str:
        """Generate a completion for *prompt*. ``stop`` sequences are unsupported."""
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Stream tokens and concatenate them into a single response string.
        response_gen = self.inference_client.text_generation(prompt, **self.kwargs, stream=True, return_full_text=False)
        response = ''.join(response_gen)
        return response

    @property
    def _llm_type(self) -> str:
        """Label LangChain uses to identify this LLM implementation."""
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        """Parameters that uniquely identify this LLM instance."""
        return {"model_name": self.model_name}
52
+
53
def check_winner(board):
    """Return the symbol ("X"/"O") completing a line on the 3x3 board, else None.

    Rows are examined first, then columns, then the two diagonals, so the
    first completed line in that order decides the result.
    """
    lines = list(board)
    lines.extend(board.T)
    lines.append(board.diagonal())
    lines.append(np.fliplr(board).diagonal())
    for line in lines:
        candidate = line[0]
        if candidate != "" and all(cell == candidate for cell in line):
            return candidate
    return None
65
+
66
def check_draw(board):
    """Return True when every cell is occupied (no empty strings remain)."""
    return bool((board != "").all())
68
+
69
def main():
    """Render the Streamlit tic-tac-toe UI and drive the human-vs-AI game loop."""
    st.set_page_config(page_title="Tic Tac Toe", page_icon=":memo:", layout="wide")

    col1, col2 = st.columns([1, 2])
    with col1:
        st.image("image.jpg", use_column_width=True)
    with col2:
        st.markdown(app_intro(), unsafe_allow_html=True)

    st.markdown("____")

    # Game state lives in st.session_state so it survives Streamlit reruns.
    scores = st.session_state.get("scores", {"X": 0, "O": 0})
    board = np.array(st.session_state.get("board", [["" for _ in range(3)] for _ in range(3)]))
    current_player = st.session_state.get("current_player", "X")
    winner = check_winner(board)

    if winner is not None:
        # Bug fix: the original incremented `scores` but never wrote it back to
        # session_state, so the updated score was lost on the next rerun.  The
        # `winner_recorded` flag prevents the same win from being counted again
        # on every subsequent rerun while the finished board is still shown.
        if not st.session_state.get("winner_recorded", False):
            scores[winner] += 1
            st.session_state.scores = scores
            st.session_state.winner_recorded = True
        st.write(f"Player {winner} wins! Score: X - {scores['X']} | O - {scores['O']}")
    elif check_draw(board):
        st.write("Draw!")
    else:
        # Draw the 3x3 grid; empty cells are buttons while it's the human's turn.
        for row in range(3):
            cols = st.columns(3)
            for col in range(3):
                button_key = f"button_{row}_{col}"
                if board[row, col] == "" and current_player == "X":
                    if cols[col].button(" ", key=button_key):
                        board[row, col] = current_player
                        st.session_state.board = board
                        st.session_state.current_player = "O"
                        progress = st.session_state.get("progress", [])
                        # Moves are logged as e.g. "X: A1" (row letter, 1-based column).
                        progress.append(f"{current_player}: {chr(65 + row)}{col + 1}")
                        st.session_state.progress = progress
                        st.experimental_rerun()
                else:
                    cols[col].write(board[row, col])

    if current_player == "O" and winner is None:
        with st.spinner("Calculating AI Move..."):
            ai_progress = ", ".join(st.session_state.progress)
            ai_move = get_ai_move(ai_progress)
            # Bug fix: the original unpacked `ai_move.split(": ")[1]` directly,
            # which crashes on stray whitespace or a missing space after the
            # colon in the model's free-form reply.
            move_token = ai_move.split(":")[-1].strip()
            ai_row = ord(move_token[0].upper()) - 65   # "A".."C" -> 0..2
            ai_col = int(move_token[1]) - 1            # "1".."3" -> 0..2
            # NOTE(review): the model's move is not validated against occupied
            # cells; an illegal reply silently overwrites the board — consider
            # re-prompting or picking a random free cell instead.
            board[ai_row, ai_col] = "O"
            st.session_state.board = board

            progress = st.session_state.get("progress", [])
            progress.append(f"O: {chr(65 + ai_row)}{ai_col + 1}")
            st.session_state.progress = progress

            st.session_state.current_player = "X"
            st.experimental_rerun()

    st.markdown("____")
    if st.button("Reset game"):
        st.session_state.board = [["" for _ in range(3)] for _ in range(3)]
        st.session_state.current_player = "X"
        st.session_state.progress = []
        st.session_state.winner_recorded = False  # allow the next win to be counted
        st.experimental_rerun()

    if st.button("Reset Scores"):
        scores["X"] = 0
        scores["O"] = 0
        st.session_state.scores = scores
        st.experimental_rerun()

    progress = st.session_state.get("progress", [])
    st.write(", ".join(progress))

    st.write(f"Score: X - {scores['X']} | O - {scores['O']}")
142
+
143
def get_ai_move(progress):
    """Ask the Mistral model for O's next move, given the moves played so far.

    *progress* is a comma-separated move log (e.g. "X: A1, O: B2"); the raw
    model reply (expected format "O: B2") is returned with any trailing
    end-of-sequence marker stripped.
    """
    print("progress", progress)
    language_model = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs)

    template = """<s>[INST] Decide the next O move in tic tac toe game[/INST]
Example output format = O: B2

{progress}

Next move:"""

    move_prompt = PromptTemplate(template=template, input_variables=["progress"])
    chain = LLMChain(prompt=move_prompt, llm=language_model)

    reply = chain.run(progress).replace("</s>", "")
    print("ai move", reply)
    return reply
162
+
163
+
164
# Entry point: run the Streamlit app when executed directly (streamlit run app.py).
if __name__ == "__main__":
    main()
image.jpg ADDED
markup.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
def app_intro():
    """Return the HTML heading shown at the top of the app page."""
    heading = """

<h1>🎮 Tic Tac Toe with Mistral 7B</h1>
"""
    return heading
requirements..txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ langchain
+ huggingface_hub
+ streamlit
+ numpy