Spaces:
Sleeping
Sleeping
import os | |
import logging | |
from typing import List | |
import streamlit as st | |
from landingai.common import Prediction | |
from landingai.predict import Predictor | |
from landingai.st_utils import render_svg, setup_page | |
from landingai.visualize import overlay_predictions | |
from PIL import Image | |
import simulation as s | |
from holdem import run_simulation | |
# Page-level configuration via the landingai Streamlit helper.
setup_page(page_title="LandingLens Holdem Odds Calculator")

# None removes Pillow's upper bound on image size (see Pillow docs), so
# large photos are not rejected when opened.
Image.MAX_IMAGE_PIXELS = None

# Keys under which (predictions, uploaded file) pairs are kept in
# st.session_state for the player's hand and the flop respectively.
_HAND, _FLOP = "hand", "flop"
def _collect_cards(uploader_label, expected_count, state_key, pred_name) -> bool:
    """Render one uploader and cache its detections in the session state.

    Args:
        uploader_label: Label shown on the file uploader widget.
        expected_count: Number of cards that must be detected in the photo.
        state_key: ``st.session_state`` key that receives ``(preds, image_file)``.
        pred_name: Name used in the error message and the image caption.

    Returns:
        ``False`` when a photo was uploaded but the number of detections does
        not match ``expected_count`` (an error has already been rendered);
        ``True`` otherwise, including when nothing was uploaded.
    """
    # NOTE: st.camera_input(uploader_label) can be swapped in here to take the
    # photo from a webcam instead of uploading a file.
    image_file = st.file_uploader(uploader_label)
    if image_file is None:
        return True
    preds = inference(image_file)
    if len(preds) != expected_count:
        _show_error_message(expected_count, preds, image_file, pred_name)
        return False
    st.session_state[state_key] = preds, image_file
    return True


def _detected_cards(state_key) -> List[str]:
    """Return the converted card names of the detections cached under state_key."""
    cached_preds = st.session_state[state_key][0]
    return [_convert_name(det.label_name) for det in cached_preds]


def main():
    """Render the app: collect the hand and flop photos, then run the simulation.

    The flow stops early (after rendering a message) when a photo has the
    wrong number of detected cards, when no hand has been provided yet, or
    when the combined cards fail validation. With only a hand available the
    simulation runs on the hand alone; with both, it runs on hand and flop.
    """
    st.markdown("""
<h3 align="center">Holdem Odds Calculator</h3>
<p align="center">
<img width="100" height="100" src="https://github.com/landing-ai/landingai-python/raw/main/assets/avi-logo.png">
</p>
""", unsafe_allow_html=True)

    hand_tab, flop_tab = st.tabs(["Your hand", "Flop"])
    with hand_tab:
        if not _collect_cards("Your hand", 2, _HAND, "hand"):
            return
    with flop_tab:
        if not _collect_cards("Flop", 3, _FLOP, "flop"):
            return

    if _HAND not in st.session_state:
        st.info("Please take a photo of your hand.")
        return

    if _FLOP not in st.session_state:
        # Only the hand is known so far: show it and simulate pre-flop.
        _show_predictions(*st.session_state[_HAND], "Your hand")
        run_simulation(hand=_detected_cards(_HAND))
        st.info("Please take a photo of the flop to continue.")
        return

    hand_col, flop_col = st.columns(2)
    with hand_col:
        _show_predictions(*st.session_state[_HAND], "Your hand")
    with flop_col:
        _show_predictions(*st.session_state[_FLOP], "Flop")

    hand = _detected_cards(_HAND)
    flop = _detected_cards(_FLOP)
    if not _validate_cards(hand, flop):
        return
    run_simulation(hand=hand, flop=flop)
    st.write("Interested in building more Computer Vision applications? Check out our [LandingLens](https://landing.ai/) platform and our [open source Python SDK](https://github.com/landing-ai/landingai-python)!")
def _validate_cards(hand, flop) -> bool:
    """Check the combined hand and flop cards, rendering an error on failure.

    Args:
        hand: Card names of the player's hand.
        flop: Card names of the flop.

    Returns:
        ``True`` when both checks pass. ``False`` when ``s.dedup`` reports a
        duplicate or ``s.validate_card`` rejects the cards; in either case a
        Streamlit error has already been shown.
    """
    check = hand + flop
    if s.dedup(check):
        st.error(
            "There is a duplicate card. Please check the board and your hand and try again.",
            # The icon was mis-encoded as "π¨" in the source; Streamlit
            # only accepts a real emoji here.
            icon="🚨",
        )
        return False
    if not s.validate_card(check):
        st.error(
            "At least one of your cards is not valid. Please try again.", icon="🚨"
        )
        return False
    return True
def _convert_name(name: str) -> str: | |
if name.startswith("10"): | |
return f"T{name[2:].lower()}" | |
else: | |
return f"{name[0].upper()}{name[1:].lower()}" | |
# TODO Rename this here and in `main`
def _show_error_message(expected_len, preds, img_file, pred_name):
    """Report a card-count mismatch and render the detections that caused it.

    Shows a Streamlit error stating how many cards were detected versus how
    many were expected, then draws ``preds`` over ``img_file`` with
    ``pred_name`` as the caption.
    """
    detected = len(preds)
    st.error(
        f"Detected {detected}, expects {expected_len} cards in your {pred_name}. "
        "Please try again with a new photo."
    )
    _show_predictions(preds, img_file, pred_name)
def inference(image_file) -> List[Prediction]:
    """Run the card-detection endpoint on an image and deduplicate the result.

    Args:
        image_file: Anything ``PIL.Image.open`` accepts (here, the object
            returned by the Streamlit uploader).

    Returns:
        The endpoint's predictions with at most one entry per label name.
    """
    image = Image.open(image_file).convert("RGB")
    # SECURITY: prefer a key supplied through the environment. The literal
    # fallback is the key that was committed to this file; it is kept only so
    # existing deployments keep working and should be rotated and removed once
    # LANDINGAI_API_KEY is configured.
    api_key = os.environ.get(
        "LANDINGAI_API_KEY",
        "land_sk_JkygHlib8SgryZUgumM6r8GWYfQqiKdE36xDzo4K85fDihpnuG",
    )
    predictor = Predictor(
        endpoint_id="2549edc1-35ad-45aa-b27e-f49e79f5922e",
        api_key=api_key,
    )
    logging.info("Running Poker prediction")
    preds = _dedup_preds(predictor.predict(image))
    # Lazy %-formatting: the message is only built if the record is emitted.
    logging.info(
        "Poker prediction done successfully against %s with %d predictions.",
        image_file,
        len(preds),
    )
    return preds
def _show_predictions(preds, image_file, caption: str) -> None:
    """Render ``image_file`` in Streamlit with ``preds`` drawn on top.

    The image is opened and converted to RGB, passed through
    ``overlay_predictions``, and displayed with the given caption.
    """
    rgb_image = Image.open(image_file).convert("RGB")
    annotated = overlay_predictions(preds, rgb_image)
    st.image(annotated, channels="RGB", caption=caption)
def _dedup_preds(preds: List[Prediction]) -> List[Prediction]:
    """Keep a single prediction per ``label_name``.

    When several predictions share a label, the one appearing last in
    ``preds`` is kept; labels stay in the order they were first seen.
    """
    by_label = {}
    for pred in preds:
        by_label[pred.label_name] = pred
    return list(by_label.values())
# Entry point: build the page only when this file is executed as a script.
if __name__ == "__main__":
    main()