Spaces:
Sleeping
Sleeping
# app.py | |
import os | |
import json | |
import streamlit as st | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments | |
from datasets import Dataset | |
import torch | |
from huggingface_hub import HfFolder | |
import subprocess | |
from AppointmentScheduler import AppointmentScheduler | |
# Authenticate with the Hugging Face Hub: hand the HF_TOKEN Space secret to
# huggingface_hub (via HfFolder.save_token) before any Hub operation such as
# the push_to_hub calls made during training.
hf_token = st.secrets["HF_TOKEN"]
HfFolder.save_token(hf_token)


def set_git_config(email='nilesh.hanotia@outlook.com', name='Nilesh'):
    """Set the global git identity (``user.email`` / ``user.name``).

    Args:
        email: Value written with ``git config --global user.email``.
        name: Value written with ``git config --global user.name``.

    Returns:
        True if both settings were applied, False otherwise. Failures are
        reported in the Streamlit UI instead of being raised, so a missing or
        failing ``git`` never prevents the app from starting.
    """
    settings = (('user.email', email), ('user.name', name))
    try:
        for key, value in settings:
            subprocess.run(['git', 'config', '--global', key, value], check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: git ran but exited non-zero.
        # OSError (e.g. FileNotFoundError): the git executable could not be
        # launched at all; previously this escaped and crashed the app.
        st.error(f"Git configuration error: {str(e)}")
        return False
    st.success("Git configuration set successfully.")
    return True


set_git_config()
def load_data(file_paths):
    """Load and concatenate intent examples from one or more JSON files.

    Each file must contain a JSON object with an ``'intents'`` list; every
    entry of that list must have an ``'examples'`` list. All examples are
    concatenated in file order, then intent order.

    Args:
        file_paths: Iterable of path strings. Surrounding whitespace is
            stripped and blank entries (e.g. the ``['']`` produced by
            splitting an empty text box) are skipped.

    Returns:
        The combined list of examples (possibly empty), or None if any file
        is missing, unreadable, or malformed. In the None case the error has
        already been shown with ``st.error``.
    """
    combined_data = []
    for file_path in file_paths:
        file_path = file_path.strip()
        if not file_path:
            # Blank entry: nothing to load, and not an error.
            continue
        if not os.path.isfile(file_path):
            st.error(f"File not found: {file_path}")
            return None
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Debug output to inspect the loaded structure.
            print(f"Data loaded from {file_path}: {data}")
            # Expected layout: {"intents": [{"examples": [...]}, ...]}
            if not isinstance(data, dict) or 'intents' not in data:
                st.error(f"Invalid format in file: {file_path}")
                return None
            for intent in data['intents']:
                combined_data.extend(intent['examples'])
        except (OSError, ValueError, KeyError, TypeError) as e:
            # OSError: unreadable file; ValueError: bad JSON / bad UTF-8;
            # KeyError/TypeError: an intent lacks a usable 'examples' list.
            st.error(f"Error loading dataset from {file_path}: {str(e)}")
            return None
    print(f"Combined data: {combined_data}")  # Check the combined dataset
    return combined_data
def initialize_model_and_tokenizer(model_name, num_labels):
    """Load a tokenizer and a sequence-classification model for *model_name*.

    Guarantees the tokenizer has a padding token, because the datasets built
    for training pad every example to a fixed length. The pad token id is
    also recorded on the model config.

    Args:
        model_name: Hub model id or local path passed to ``from_pretrained``.
        num_labels: Number of classification labels for the model head.

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` if loading
        fails; in that case the error has been shown with ``st.error``.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        # Only touch padding when the tokenizer has none; overwriting an
        # existing pad token (or setting it to a missing EOS) would break it.
        if tokenizer.pad_token is None:
            if tokenizer.eos_token is not None:
                # Some tokenizers (e.g. the GPT-2 family, like the default
                # distilgpt2) define no pad token; reuse EOS. This maps padding
                # onto an existing id, so the embedding matrix must NOT be
                # resized.
                tokenizer.pad_token = tokenizer.eos_token
            else:
                # No EOS to reuse: add a genuinely new token and grow the
                # embeddings so its id is valid.
                tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                model.resize_token_embeddings(len(tokenizer))
        # The model must know the pad id to locate the last real token in
        # padded batches.
        model.config.pad_token_id = tokenizer.pad_token_id
        return tokenizer, model
    except Exception as e:
        st.error(f"Error initializing model and tokenizer: {str(e)}")
        return None, None
def create_dataset(data, tokenizer, max_length, num_labels):
    """Tokenize prompt/label examples into a ``datasets.Dataset`` for Trainer.

    Args:
        data: List of example dicts. Items without a truthy ``'prompt'`` are
            skipped; a missing ``'label'`` defaults to 0.
        tokenizer: Hugging Face tokenizer with a pad token set.
        max_length: Every sequence is truncated/padded to exactly this length.
        num_labels: Labels outside ``[0, num_labels)`` are coerced to 0.

    Returns:
        A Dataset with ``input_ids``, ``attention_mask`` and ``labels`` columns.

    Raises:
        ValueError: if no item in *data* has a prompt.
    """
    examples = [item for item in data if item.get('prompt')]
    if not examples:
        raise ValueError("The input texts list is empty. Please check your data.")

    texts = [item['prompt'] for item in examples]
    # Out-of-range labels are silently mapped to 0 rather than rejected.
    labels = [
        label if 0 <= label < num_labels else 0
        for label in (item.get('label', 0) for item in examples)
    ]

    # Keep the encodings as plain Python lists (no return_tensors='pt'):
    # Dataset.from_dict is documented to take lists/arrays, and the Trainer's
    # default data collator builds tensors per batch anyway.
    encodings = tokenizer(
        texts,
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )
    return Dataset.from_dict({
        'input_ids': encodings['input_ids'],
        'attention_mask': encodings['attention_mask'],
        'labels': labels,
    })
def _render_training_section():
    """Render the fine-tuning controls; train and push only on button press.

    Model loading, dataset building and Trainer construction all happen
    inside the button branch. Streamlit re-executes the whole script on every
    widget interaction, and a Trainer created with push_to_hub=True
    initialises the Hub repo when it is constructed. Doing this work eagerly
    therefore repeated it on every rerun, including every chat message.
    """
    import dataclasses
    import random

    model_name = st.text_input("Enter model name", "distilgpt2")
    raw_paths = st.text_area("Enter training data paths")
    # Accept comma- and/or newline-separated paths and drop blank entries, so
    # an empty text area yields no paths instead of [''].
    file_paths = [p.strip() for p in raw_paths.replace("\n", ",").split(",") if p.strip()]
    # st.number_input's second positional parameter is min_value, not the
    # default value, so defaults are passed by keyword. The learning rate also
    # needs a fine step and an exponent display format; with the float
    # defaults, 5e-05 would be shown as 0.00 with a step of 0.01.
    max_length = st.number_input("Max token length", min_value=1, value=128, step=1)
    num_epochs = st.number_input("Training epochs", min_value=1, value=3, step=1)
    batch_size = st.number_input("Batch size", min_value=1, value=8, step=1)
    learning_rate = st.number_input(
        "Learning rate", min_value=0.0, value=5e-5, step=1e-5, format="%.1e"
    )
    num_labels = 3
    repo_id = st.text_input("Hugging Face Repo ID", "nileshhanotia/PeVe")

    if not st.button('Start Training'):
        return
    if not file_paths:
        st.error("Please enter at least one training data path.")
        return

    tokenizer, model = initialize_model_and_tokenizer(model_name, num_labels)
    if tokenizer is None or model is None:
        return  # the loader has already reported the error

    data = load_data(file_paths)
    if data is None:
        return  # the loader has already reported the error
    if not data:
        st.error("No training examples were found in the given files.")
        return
    print(f"Total data loaded: {len(data)}")
    print(f"Sample data item: {data[0]}")

    # load_data concatenates examples intent by intent, so an unshuffled
    # 80/20 split would put only the last intent(s) into the eval set.
    # Shuffle a copy with a fixed seed so runs stay reproducible.
    data = list(data)
    random.Random(42).shuffle(data)
    split = int(len(data) * 0.8)
    train_data, eval_data = data[:split], data[split:]
    print(f"Train data size: {len(train_data)}, Eval data size: {len(eval_data)}")

    try:
        train_dataset = create_dataset(train_data, tokenizer, max_length, num_labels)
        eval_dataset = create_dataset(eval_data, tokenizer, max_length, num_labels)
    except ValueError as e:
        # e.g. too few examples for one of the splits to contain a prompt.
        st.error(f"Could not build the training/eval datasets: {e}")
        return
    print(f"Train dataset size: {len(train_dataset)}, Eval dataset size: {len(eval_dataset)}")
    print(f"Sample train item: {train_dataset[0]}")

    # transformers renamed `evaluation_strategy` to `eval_strategy` and later
    # dropped the old name; use whichever the installed version declares.
    arg_names = {field.name for field in dataclasses.fields(TrainingArguments)}
    eval_key = "eval_strategy" if "eval_strategy" in arg_names else "evaluation_strategy"
    training_args = TrainingArguments(
        output_dir='./results',
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_epochs,
        logging_dir='./logs',
        push_to_hub=True,
        hub_model_id=repo_id,
        **{eval_key: 'epoch'},
    )

    st.write("Training model...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    trainer.push_to_hub()
    st.write(f"Model pushed to: {repo_id}")


def _render_scheduler_section():
    """Render the appointment-scheduler conversation UI.

    The AppointmentScheduler instance and the transcript are kept in
    st.session_state so they persist across Streamlit reruns.
    """
    st.header("Appointment Scheduler")
    state = st.session_state
    # Initialize session state for conversation history and scheduler.
    if 'conversation_history' not in state:
        state.conversation_history = []
        state.scheduler = AppointmentScheduler()
        state.first_interaction = True

    # A bare st.text_input keeps its value across reruns, so the same patient
    # message was re-sent to the scheduler whenever any other widget (e.g. the
    # training button) caused a rerun. A clear-on-submit form delivers each
    # message exactly once.
    with st.form("patient_response_form", clear_on_submit=True):
        user_input = st.text_input("Enter patient response")
        submitted = st.form_submit_button("Send")

    if submitted and user_input:
        if state.first_interaction:
            # Open the conversation with the scheduler's greeting first.
            greeting = state.scheduler.handle_incoming_speech("hello")
            state.conversation_history.append(("Assistant", greeting))
            state.first_interaction = False
        response = state.scheduler.handle_incoming_speech(user_input)
        state.conversation_history.append(("Patient", user_input))
        state.conversation_history.append(("Assistant", response))

    # Display conversation history.
    for speaker, message in state.conversation_history:
        st.write(f"{speaker}: {message}")


def main():
    """App entry point: render the training section, then the scheduler chat."""
    st.title("Appointment Scheduling Platform")
    _render_training_section()
    _render_scheduler_section()


if __name__ == "__main__":
    main()