# File size: 2,347 Bytes
# e8a8dd9
import streamlit as st
from pages import set_app_title_and_logo, qb_gpt_page, contacts_and_disclaimers
import json
import pandas as pd
import numpy as np
import os
from tools import tokenizer
from assets.models import QBGPT
# Size hyper-parameters handed to the QBGPT constructor below.
# NOTE(review): presumably these must match the checkpoint loaded from
# app/assets/model_mediumv2 — confirm against the training configuration.
moves_to_pred = 11170
input_size = 11172
starts_size = 1954
scrimmage_size = 100
positions_id = 29
temp_ids = 52
off_def_size = 2
token_type_size = 3
play_type_size = 9

# Every constructor argument in one mapping, so the architecture can be
# read in a single place.
_qbgpt_config = dict(
    input_vocab_size=input_size,
    positional_vocab_size=temp_ids,
    position_vocab_size=positions_id,
    start_vocab_size=starts_size,
    scrimmage_vocab_size=scrimmage_size,
    offdef_vocab_size=off_def_size,
    type_vocab_size=token_type_size,
    playtype_vocab_size=play_type_size,
    embedding_dim=256,
    hidden_dim=256,
    num_heads=3,
    diag_masks=False,
    to_pred_size=moves_to_pred,
)

# Build the network, then restore its pretrained weights from disk.
# NOTE(review): Streamlit re-executes this script on every user interaction,
# so the weights are reloaded each time; consider st.cache_resource.
qbgpt = QBGPT(**_qbgpt_config)
qbgpt.load_weights("app/assets/model_mediumv2/QBGPT")
# Parquet lookup tables used to build the tokenizer, keyed by the
# tokenizer's keyword-argument names.
_tokenizer_indexes = {
    "moves_index": "./app/assets/moves_index.parquet",
    "play_index": "./app/assets/plays_index.parquet",
    "positions_index": "./app/assets/positions_index.parquet",
    "scrimmage_index": "./app/assets/scrimmage_index.parquet",
    "starts_index": "./app/assets/starts_index.parquet",
    "time_index": "./app/assets/time_index.parquet",
}
qb_tok = tokenizer(window_size=20, **_tokenizer_indexes)
# Reference data later passed to the QB-GPT page.
# JSON text is UTF-8 (RFC 8259), so the encoding is pinned instead of being
# left to the platform's locale default. (A leftover debug print of the
# "app" directory listing was removed here.)
with open('./app/assets/ref.json', 'r', encoding='utf-8') as fp:
    ref_json = json.load(fp)
def convert_numpy(d):
    """Return a new dict with the same keys as *d* and each value wrapped by np.array."""
    arrays = {}
    for key, value in d.items():
        arrays[key] = np.array(value)
    return arrays
# JSON object keys always load as strings: convert them to int, and pass
# every value through convert_numpy. Keys and values of the same dict
# iterate in matching order, so zip pairs them up correctly.
ref_json = dict(
    zip(
        map(int, ref_json.keys()),
        map(convert_numpy, ref_json.values()),
    )
)
# Tabular reference data, passed to the QB-GPT page together with ref_json.
ref_df = pd.read_json("./app/assets/ref_df.json")
def main():
    """Run the Streamlit app: title/logo, sidebar navigation, selected page.

    The QB-GPT page is rendered with the module-level ``ref_df``,
    ``ref_json``, ``qb_tok`` and ``qbgpt`` objects.
    """
    set_app_title_and_logo()

    # Sidebar navigation between the app's two pages.
    st.sidebar.title("Navigation")
    page = st.sidebar.radio("Go to:", ("QB-GPT", "Contacts and Disclaimers"))

    # ``page`` can equal at most one of these labels, so the branches are
    # mutually exclusive.
    if page == "QB-GPT":
        # First page: the QB-GPT model interface.
        st.title("QB-GPT")
        qb_gpt_page(ref_df, ref_json, qb_tok, qbgpt)
    elif page == "Contacts and Disclaimers":
        contacts_and_disclaimers()
# Script entry point. The stray "|" left behind by the web scrape after
# main() made this line a SyntaxError and has been removed.
if __name__ == "__main__":
    main()