import os

import pandas as pd
import streamlit as st
import torch
from huggingface_hub import login
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from transformers.tokenization_utils_base import BatchEncoding
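
# BatchEncoding is imported only for the isinstance() checks in generate_response.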

# st.set_page_config must be the first Streamlit command in the script,
# so it runs before any st.error() call below.
st.set_page_config(page_title="AnthroBot", page_icon="🤖", layout="centered")

# Authenticate with the Hugging Face Hub using a token from the environment.
try:
    login(token=os.getenv("HUGGINGFACEHUB_TOKEN"))
except Exception as e:
    st.error(f"Error logging in to Hugging Face: {str(e)}")
    st.stop()
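
# Model loading: st.cache_resource keeps one model/tokenizer pair alive
# across Streamlit reruns and sessions, so the weights are downloaded and
# placed on the device only once per server process.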
@st.cache_resource
def load_model():
    try:
        # The adapter's config records which base checkpoint it was trained on.
        peft_config = PeftConfig.from_pretrained("SallySims/AnthroBot_Model_Lora")
        base_model = AutoModelForCausalLM.from_pretrained(
            peft_config.base_model_name_or_path,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            token=True,
        )
        model = PeftModel.from_pretrained(base_model, "SallySims/AnthroBot_Model_Lora")
        model.eval()

        tokenizer = AutoTokenizer.from_pretrained(
            peft_config.base_model_name_or_path,
            trust_remote_code=True,
            token=True,
        )
        # Causal LM tokenizers often ship without a pad token; reuse EOS.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id

        st.write("✅ Model and tokenizer loaded successfully.")
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        raise


model, tokenizer = load_model()

# Keep a per-session record of (prompt, response) pairs.
if 'history' not in st.session_state:
    st.session_state.history = []

# device_map="auto" already placed the model; this string is used only to
# move input tensors, and assumes a single-device (GPU or CPU) setup.
device = "cuda" if torch.cuda.is_available() else "cpu"
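
# Inference. The flat "Age/Sex/Height/Weight/WC" prompt below mirrors the
# format the LoRA adapter was presumably fine-tuned on.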
def generate_response(age, sex, height_cm, weight_kg, wc_cm):
    try:
        prompt = f"Age: {age}, Sex: {sex}, Height: {height_cm} cm, Weight: {weight_kg} kg, WC: {wc_cm} cm"
        st.write(f"📋 Prompt Sent to Model: `{prompt}`")

        # Wrap the prompt in the single-turn chat format expected by the model.
        messages = [{"role": "user", "content": prompt}]
        # Preferred path: apply the model's chat template directly.
        try:
            inputs = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                return_dict=True,
            )
        except Exception as e:
            # Tokenizers without a chat template land here; tokenize the raw prompt.
            st.warning(f"apply_chat_template failed: {str(e)}. Falling back to manual tokenization.")
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=False,
                return_attention_mask=True,
            )
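        # Both paths cap the prompt at 512 tokens; anthropometric prompts are
        # short, so truncation should rarely, if ever, trigger.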
        st.write(f"Inputs type: {type(inputs)}")
        st.write(f"Inputs keys: {list(inputs.keys()) if isinstance(inputs, (dict, BatchEncoding)) else 'N/A'}")

        # Normalize the two possible return types: a dict/BatchEncoding or a bare tensor.
        if isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs['input_ids']
            attention_mask = inputs.get('attention_mask', torch.ones_like(input_ids))
        elif isinstance(inputs, torch.Tensor):
            input_ids = inputs
            attention_mask = torch.ones_like(input_ids)
        else:
            st.error(f"Unexpected inputs format: {type(inputs)}")
            return None
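        # apply_chat_template's return type and shape vary across transformers
        # versions, so coerce whatever came back into a 2-D (batch, seq_len)
        # tensor before calling model.generate.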
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)
        elif input_ids.dim() > 2:
            input_ids = input_ids.squeeze()
            attention_mask = attention_mask.squeeze()
            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)
                attention_mask = attention_mask.unsqueeze(0)

        st.write(f"Input IDs shape: {input_ids.shape}")
        st.write(f"Attention mask shape: {attention_mask.shape}")
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        st.write("🤖 Model response:")
        with st.empty():
            # Note: TextStreamer prints tokens to the server console (stdout),
            # not into the Streamlit page; the full text is rendered below instead.
            text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
            # Sample with temperature/top_p rather than greedy decoding, so
            # summaries vary slightly between runs.
            output = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=250,
                temperature=0.7,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=True,
                streamer=text_streamer,
            )
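        # Generation runs synchronously; with max_new_tokens=250 this can take
        # a while on CPU, so batch (CSV) runs scale linearly with row count.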
        # generate() returns prompt + continuation; decode only the newly
        # generated tokens and drop special tokens for display.
        decoded = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
        st.write(f"Decoded output: {decoded}")

        st.session_state.history.append((prompt, decoded))
        return decoded

    except Exception as e:
        st.error(f"Error during generation: {str(e)}")
        return None
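
# UI: one tab for manual entry, one for batch CSV upload.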
st.title("🧍 AnthroBot")
st.markdown("Enter your anthropometric details to receive an AI-generated summary of health metrics.")

tab1, tab2 = st.tabs(["📝 Manual Input", "📁 CSV Upload"])

with tab1:
    st.subheader("Manual Entry")
    age = st.number_input("Age", min_value=1, max_value=120, value=30)
    sex = st.selectbox("Sex", options=["male", "female"])
    height = st.number_input("Height (cm)", min_value=50.0, max_value=250.0, value=170.0)
    weight = st.number_input("Weight (kg)", min_value=10.0, max_value=300.0, value=70.0)
    wc = st.number_input("Waist Circumference (cm)", min_value=20.0, max_value=200.0, value=80.0)

    if st.button("Estimate Metrics"):
        prediction = generate_response(age, sex, height, weight, wc)
        if prediction:
            st.success("Prediction:")
            st.write(prediction)

    st.subheader("Prediction History")
    for prompt, response in st.session_state.history:
        st.markdown(f"**Input**: {prompt}")
        st.markdown(f"**Output**: {response}")

with tab2:
    st.subheader("Batch Upload via CSV")
    # One-row template users can download, fill in, and re-upload.
    sample_csv = pd.DataFrame({
        "Age": [30],
        "Sex": ["male"],
        "Height": [170.0],
        "Weight": [70.0],
        "WC": [80.0]
    })

    st.download_button("📥 Download Sample CSV", sample_csv.to_csv(index=False), file_name="sample_input.csv")

    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])

    if uploaded_file:
        df = pd.read_csv(uploaded_file)
        if not all(col in df.columns for col in ["Age", "Sex", "Height", "Weight", "WC"]):
            st.error("CSV must contain columns: Age, Sex, Height, Weight, WC")
        else:
            outputs = []
            with st.spinner("Generating predictions..."):
                # Run each row through the model sequentially.
                for _, row in df.iterrows():
                    prediction = generate_response(row['Age'], row['Sex'], row['Height'], row['Weight'], row['WC'])
                    outputs.append(prediction if prediction else "Error")

            df["Prediction"] = outputs
            st.success("Here are your predictions:")
            st.dataframe(df)

            csv_output = df.to_csv(index=False).encode("utf-8")
            st.download_button("📤 Download Predictions", data=csv_output, file_name="predictions.csv")
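
# The Clear History control sits outside both tabs, below the tab container,
# so it is reachable from either view.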
if st.button("Clear History"):
    st.session_state.history = []
    st.rerun()