# LandingAI-Poker / app.py
# Last commit 451d869 by dillonlaird: "updated predictor"
import os
import logging
from typing import List
import streamlit as st
from landingai.common import Prediction
from landingai.predict import Predictor
from landingai.st_utils import render_svg, setup_page
from landingai.visualize import overlay_predictions
from PIL import Image
import simulation as s
from holdem import run_simulation
# Configure the Streamlit page via the landingai helper.
# NOTE(review): this presumably wraps st.set_page_config, which Streamlit
# requires to be the first Streamlit command in the script — confirm.
setup_page(page_title="LandingLens Holdem Odds Calculator")
# NOTE(review): setting MAX_IMAGE_PIXELS to None disables Pillow's
# decompression-bomb size check for every image opened by this app, including
# user uploads passed to Image.open in `inference` and `_show_predictions`.
# Consider restoring a finite limit large enough for phone photos.
Image.MAX_IMAGE_PIXELS = None
# st.session_state keys under which each photo's (predictions, uploaded_file)
# pair is stored once its detected card count has been accepted.
_HAND = "hand"
_FLOP = "flop"
def main():
    """Render the Hold'em odds calculator page.

    Two tabs accept a photo of the player's hand (2 cards) and of the flop
    (3 cards). Each photo is run through ``inference``; when the detected card
    count is right, the ``(predictions, uploaded_file)`` pair is stored in
    ``st.session_state`` so it survives Streamlit reruns. With only a hand
    available the simulation runs on the hand alone; once the flop is also
    available it runs on hand + flop.
    """
    st.markdown("""
<h3 align="center">Holdem Odds Calculator</h3>
<p align="center">
<img width="100" height="100" src="https://github.com/landing-ai/landingai-python/raw/main/assets/avi-logo.png">
</p>
""", unsafe_allow_html=True)
    tab1, tab2 = st.tabs(["Your hand", "Flop"])
    # Render BOTH tabs before bailing out on a bad photo. Streamlit discards the
    # state of widgets that are not rendered in a run, so returning from inside
    # the first tab would silently drop an already-uploaded flop photo.
    with tab1:
        hand_ok = _process_upload("Your hand", _HAND, "hand", 2)
    with tab2:
        flop_ok = _process_upload("Flop", _FLOP, "flop", 3)
    if not (hand_ok and flop_ok):
        return

    if _HAND not in st.session_state:
        st.info("Please take a photo of your hand.")
        return

    if _FLOP not in st.session_state:
        _show_predictions(*st.session_state[_HAND], "Your hand")
        hand = _cards_in_state(_HAND)
        # Validate exactly like the hand + flop path below, so an invalid
        # detection is reported instead of being handed to the simulator.
        if not _validate_cards(hand, []):
            return
        run_simulation(hand=hand)
        st.info("Please take a photo of the flop to continue.")
        return

    col1, col2 = st.columns(2)
    with col1:
        _show_predictions(*st.session_state[_HAND], "Your hand")
    with col2:
        _show_predictions(*st.session_state[_FLOP], "Flop")
    hand = _cards_in_state(_HAND)
    flop = _cards_in_state(_FLOP)
    if not _validate_cards(hand, flop):
        return
    run_simulation(hand=hand, flop=flop)
    st.write("Interested in building more Computer Vision applications? Check out our [LandingLens](https://landing.ai/) platform and our [open source Python SDK](https://github.com/landing-ai/landingai-python)!")


def _process_upload(label, state_key, pred_name, expected_count):
    """Show one file uploader and store its detections in session state.

    Args:
        label: Uploader label (also determines the widget's identity).
        state_key: ``st.session_state`` key to store ``(preds, file)`` under.
        pred_name: Name used in the error message ("hand" / "flop").
        expected_count: Number of cards the photo must contain.

    Returns:
        False (after showing an error) when the detected card count differs
        from ``expected_count``; True otherwise, including when nothing has
        been uploaded yet.
    """
    image_file = st.file_uploader(label)
    # Alternative input used during development: st.camera_input(label)
    if image_file is None:
        return True
    preds = inference(image_file)
    if len(preds) != expected_count:
        _show_error_message(expected_count, preds, image_file, pred_name)
        return False
    st.session_state[state_key] = preds, image_file
    return True


def _cards_in_state(state_key):
    """Return the simulator-notation card names stored under ``state_key``."""
    preds, _ = st.session_state[state_key]
    return [_convert_name(det.label_name) for det in preds]
def _validate_cards(hand, flop) -> bool:
    """Check the combined hand + flop cards before simulating.

    Duplicates are checked first, then card validity; the first failing check
    shows a Streamlit error and the function returns False.

    Args:
        hand: Card names for the player's hand.
        flop: Card names for the flop (may be empty).

    Returns:
        True when no problem was found, False after an error was displayed.
    """
    all_cards = hand + flop
    if s.dedup(all_cards):
        message = "There is a duplicate card. Please check the board and your hand and try again."
    elif not s.validate_card(all_cards):
        message = "At least one of your cards is not valid. Please try again."
    else:
        return True
    st.error(message, icon="🚨")
    return False
def _convert_name(name: str) -> str:
if name.startswith("10"):
return f"T{name[2:].lower()}"
else:
return f"{name[0].upper()}{name[1:].lower()}"
def _show_error_message(expected_len, preds, img_file, pred_name):
    """Report a wrong number of detected cards and show what was detected.

    Args:
        expected_len: Number of cards the photo should contain.
        preds: Predictions returned for the photo.
        img_file: The uploaded image file.
        pred_name: Which photo this is; used in the message and as the caption.
    """
    detected = len(preds)
    st.error(
        f"Detected {detected}, expects {expected_len} cards in your {pred_name}. "
        "Please try again with a new photo."
    )
    _show_predictions(preds, img_file, pred_name)
@st.cache_data
def inference(image_file) -> List[Prediction]:
    """Detect playing cards in an uploaded image via the LandingLens endpoint.

    The image is converted to RGB and sent to the endpoint. Repeated detections
    of the same card label are then collapsed by ``_dedup_preds``. Results are
    memoised by ``st.cache_data``.

    Args:
        image_file: The file object returned by the Streamlit uploader.

    Returns:
        One prediction per distinct card label.
    """
    image = Image.open(image_file).convert("RGB")
    # SECURITY NOTE(review): the literal key below is committed to source and
    # should be rotated and removed. Set LANDINGAI_API_KEY in the deployment
    # environment; the literal remains only as a fallback for existing deploys.
    api_key = os.environ.get(
        "LANDINGAI_API_KEY",
        "land_sk_JkygHlib8SgryZUgumM6r8GWYfQqiKdE36xDzo4K85fDihpnuG",
    )
    predictor = Predictor(
        endpoint_id="2549edc1-35ad-45aa-b27e-f49e79f5922e",
        api_key=api_key,
    )
    logging.info("Running Poker prediction")
    preds = _dedup_preds(predictor.predict(image))
    # Lazy %-formatting: the message is only built if INFO logging is enabled.
    logging.info(
        "Poker prediction done successfully against %s with %d predictions.",
        image_file,
        len(preds),
    )
    return preds
def _show_predictions(preds, image_file, caption: str) -> None:
    """Draw ``preds`` on top of the uploaded image and display it.

    Args:
        preds: Predictions to overlay.
        image_file: The uploaded image file.
        caption: Caption shown under the image.
    """
    rgb_image = Image.open(image_file).convert("RGB")
    annotated = overlay_predictions(preds, rgb_image)
    st.image(annotated, channels="RGB", caption=caption)
def _dedup_preds(preds: List[Prediction]) -> List[Prediction]:
    """Collapse predictions that share the same ``label_name``.

    For each label the last prediction in ``preds`` wins. Labels keep the
    order in which they were first seen.
    """
    by_label = {}
    for pred in preds:
        by_label[pred.label_name] = pred
    return list(by_label.values())
# Entry point: render the app when this file is executed as the main script.
# NOTE(review): presumably launched via `streamlit run`, which executes the
# file with __name__ == "__main__" — confirm for the deployment setup.
if __name__ == "__main__":
    main()