# LandingAI-Poker / app.py
# Source: initial commit b578b56 by dillonlaird (4.72 kB).
import os
import logging
from typing import List
import streamlit as st
from landingai.common import Prediction
from landingai.predict import Predictor
from landingai.st_utils import render_svg, setup_page
from landingai.visualize import overlay_predictions
from PIL import Image
import simulation as s
from holdem import run_simulation
# Presumably wraps st.set_page_config, which Streamlit requires to be the first
# Streamlit call in the script - keep this statement first.
setup_page(page_title="LandingLens Holdem Odds Calculator")
# Uploaded photos are untrusted input. Setting this to None would disable
# Pillow's decompression-bomb protection entirely. Keep a generous but finite
# cap instead: Pillow warns above this many pixels and raises
# DecompressionBombError above twice this many (here, 200 MP).
Image.MAX_IMAGE_PIXELS = 100_000_000
# st.session_state keys holding the (predictions, uploaded file) tuples.
_HAND = "hand"
_FLOP = "flop"
def main():
    """Render the Holdem odds calculator page.

    The user uploads a photo of their hand ("Your hand" tab) and of the flop
    ("Flop" tab). Each photo runs through ``inference``. A photo with the
    expected card count (2 for the hand, 3 for the flop) is stored in
    ``st.session_state`` under ``_HAND`` / ``_FLOP`` as a
    ``(predictions, uploaded_file)`` tuple. The odds simulation then runs on
    whatever cards are available.
    """
    st.markdown("""
<h3 align="center">Holdem Odds Calculator</h3>
<p align="center">
<img width="100" height="100" src="https://github.com/landing-ai/landingai-python/raw/main/assets/avi-logo.png">
</p>
""", unsafe_allow_html=True)
    tab1, tab2 = st.tabs(["Your hand", "Flop"])
    # Process BOTH tabs before bailing out on a bad photo. Returning from inside
    # the first tab would leave the second tab (and its uploader) unrendered.
    with tab1:
        image_file_hand = st.file_uploader("Your hand")
        # image_file_hand = image_file = st.camera_input("Your hand")
        hand_ok = _process_upload(image_file_hand, _HAND, 2, "hand")
    with tab2:
        image_file_flop = st.file_uploader("Flop")
        # image_file_flop = image_file = st.camera_input(label="Flop")
        flop_ok = _process_upload(image_file_flop, _FLOP, 3, "flop")
    if not (hand_ok and flop_ok):
        # An error with the offending photo has already been shown.
        return

    if _HAND not in st.session_state:
        st.info("Please take a photo of your hand.")
        return
    hand_preds, hand_file = st.session_state[_HAND]
    hand = _cards_from_preds(hand_preds)

    if _FLOP not in st.session_state:
        _show_predictions(hand_preds, hand_file, "Your hand")
        # Validate the hand on its own too, so a bad hand is not simulated.
        if not _validate_cards(hand, []):
            return
        run_simulation(hand=hand)
        st.info("Please take a photo of the flop to continue.")
        return

    flop_preds, flop_file = st.session_state[_FLOP]
    col1, col2 = st.columns(2)
    with col1:
        _show_predictions(hand_preds, hand_file, "Your hand")
    with col2:
        _show_predictions(flop_preds, flop_file, "Flop")
    flop = _cards_from_preds(flop_preds)
    if not _validate_cards(hand, flop):
        return
    run_simulation(hand=hand, flop=flop)
    st.write("Interested in building more Computer Vision applications? Check out our [LandingLens](https://landing.ai/) platform and our [open source Python SDK](https://github.com/landing-ai/landingai-python)!")


def _process_upload(image_file, state_key: str, expected_len: int, pred_name: str) -> bool:
    """Detect cards in one uploaded photo and sync the result into session state.

    Args:
        image_file: value returned by ``st.file_uploader`` (``None`` when empty).
        state_key: ``st.session_state`` key for this photo (``_HAND``/``_FLOP``).
        expected_len: number of cards the photo must contain.
        pred_name: human-readable name used in the error message and caption.

    Returns:
        False when a photo was uploaded but the wrong number of cards was
        detected (an error has been displayed); True otherwise.
    """
    if image_file is None:
        # No (or removed) upload: drop detections from an earlier photo so
        # stale cards are not shown or simulated.
        st.session_state.pop(state_key, None)
        return True
    preds = inference(image_file)
    if len(preds) != expected_len:
        _show_error_message(expected_len, preds, image_file, pred_name)
        # The previous photo was replaced by this bad one; forget its cards.
        st.session_state.pop(state_key, None)
        return False
    st.session_state[state_key] = preds, image_file
    return True


def _cards_from_preds(preds) -> list:
    """Convert detected card labels to simulator notation via ``_convert_name``."""
    return [_convert_name(det.label_name) for det in preds]
def _validate_cards(hand, flop) -> bool:
    """Check the combined hand + flop cards before running the simulation.

    The checks are ``s.dedup`` (a truthy result is treated as "duplicate
    found"), then ``s.validate_card``. On the first failing check a Streamlit
    error is displayed and False is returned. If both pass, True is returned.
    """
    cards = hand + flop
    if s.dedup(cards):
        problem = "There is a duplicate card. Please check the board and your hand and try again."
    elif not s.validate_card(cards):
        problem = "At least one of your cards is not valid. Please try again."
    else:
        return True
    st.error(problem, icon="🚨")
    return False
def _convert_name(name: str) -> str:
if name.startswith("10"):
return f"T{name[2:].lower()}"
else:
return f"{name[0].upper()}{name[1:].lower()}"
# TODO: consider a more descriptive name (the callers in `main` must be renamed too).
def _show_error_message(expected_len, preds, img_file, pred_name):
    """Report a wrong card count for one uploaded photo.

    Displays a Streamlit error comparing the detected and expected card counts.
    It then renders the photo with its detections so the user can see what
    went wrong.

    Args:
        expected_len: number of cards the photo should contain (``main`` passes
            2 for the hand and 3 for the flop).
        preds: predictions returned by ``inference`` for the photo.
        img_file: the uploaded image file.
        pred_name: human-readable name of the photo, also used as its caption.
    """
    msg = (
        f"Detected {len(preds)} card(s), but expected {expected_len} cards in your "
        f"{pred_name}. Please try again with a new photo."
    )
    # Same icon as the other card errors raised by `_validate_cards`.
    st.error(msg, icon="🚨")
    _show_predictions(preds, img_file, pred_name)
@st.cache_data
def inference(image_file) -> List[Prediction]:
    """Detect playing cards in an uploaded photo using a LandingLens endpoint.

    The result is memoised by ``st.cache_data``, keyed on the argument.

    Args:
        image_file: uploaded image (anything ``PIL.Image.open`` accepts).

    Returns:
        The endpoint's predictions, with repeated labels collapsed by
        ``_dedup_preds``.
    """
    image = Image.open(image_file).convert("RGB")
    # The API key is read from the LANDINGAI_API_KEY environment variable.
    predictor = Predictor(endpoint_id="2549edc1-35ad-45aa-b27e-f49e79f5922e", api_key=os.environ.get("LANDINGAI_API_KEY"))
    logging.info("Running Poker prediction")
    preds = _dedup_preds(predictor.predict(image))
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logging.info(
        "Poker prediction done successfully against %s with %d predictions.",
        image_file,
        len(preds),
    )
    return preds
def _show_predictions(preds, image_file, caption: str) -> None:
    """Display *image_file* with *preds* overlaid, captioned with *caption*."""
    rgb_image = Image.open(image_file).convert("RGB")
    annotated = overlay_predictions(preds, rgb_image)
    st.image(annotated, channels="RGB", caption=caption)
def _dedup_preds(preds: List[Prediction]) -> List[Prediction]:
    """Collapse repeated detections so each ``label_name`` appears once.

    For each label, the prediction with the highest ``score`` is kept rather
    than an arbitrary one. The output lists labels in the order they first
    appear in *preds*.
    """
    best = {}  # label_name -> highest-scoring prediction seen so far
    for pred in preds:
        current = best.get(pred.label_name)
        if current is None or pred.score > current.score:
            # Re-assigning an existing key keeps its original position.
            best[pred.label_name] = pred
    return list(best.values())
# Entry point: `streamlit run` executes this file as __main__, rendering the page.
if __name__ == "__main__":
    main()