"""AI scheduling assistant: a Gradio text front-end over a text-file event store.

Users type requests such as ``schedule standup from 9:00 AM to 9:15 AM
tomorrow`` or ``list events``.  Events are kept in the module-level ``events``
list and persisted to ``events.txt`` as ``name|start|end|recurring`` lines.
All event times are treated as UTC wall-clock times.
"""
import calendar
import datetime
import functools
import os
import re

import dateutil.parser
import pytz

EVENTS_FILE = "events.txt"

# Index in each list matches date.weekday() / (month number - 1).  Fixed English
# names are used instead of strftime("%A") / ("%B") so matching does not depend
# on the process locale.
_WEEKDAYS = [
    "monday", "tuesday", "wednesday", "thursday",
    "friday", "saturday", "sunday",
]
_MONTHS = [
    "january", "february", "march", "april", "may", "june",
    "july", "august", "september", "october", "november", "december",
]
_ORDINALS = {"first": 1, "second": 2, "third": 3, "fourth": 4, "last": -1}

_CLOCK = r"\d{1,2}:\d{2}\s*[aApP][mM]"
_PREFIX = (
    r"(?:schedule|create)(?P<summary>.*?)"
    rf"from\s*(?P<start>{_CLOCK})\s*to\s*(?P<end>{_CLOCK})\s*"
)
# Most specific phrases first: the generic "on <word>" pattern would otherwise
# swallow "on the first Monday of every month" with <word> = "the".
_DATE_PATTERNS = tuple(
    (kind, re.compile(_PREFIX + tail, re.IGNORECASE))
    for kind, tail in (
        (
            "ordinal_weekday",
            r"on\s*the\s*(?P<ordinal>last|first|second|third|fourth)\s*"
            r"(?P<weekday>\w+)\s*of\s*every\s*(?P<period>\w+)",
        ),
        (
            "day_of_month",
            r"on\s*the\s*(?P<day>\d{1,2})(?:st|nd|rd|th)?\s*"
            r"of\s*every\s*(?P<period>\w+)",
        ),
        ("every_weekday", r"every\s*(?P<weekday>\w+)"),
        ("on_weekday", r"on\s*(?P<weekday>\w+)"),
        ("tomorrow", r"tomorrow"),
    )
)


def _as_utc(moment):
    """Return *moment* as a UTC-aware datetime (naive values are taken as UTC)."""
    if moment.tzinfo is None:
        return pytz.utc.localize(moment)
    return moment.astimezone(pytz.utc)


def _load_events(path=EVENTS_FILE):
    """Read persisted events from *path*; return an empty list if it is absent.

    Lines that do not have four ``|``-separated fields, or whose dates cannot
    be parsed, are skipped so one bad line cannot prevent the app from starting.
    """
    loaded = []
    if not os.path.isfile(path):
        return loaded
    with open(path, "r") as f:
        for line in f:
            # Split from the right so an event name containing "|" survives.
            fields = line.strip().rsplit("|", 3)
            if len(fields) != 4:
                continue
            name, start_str, end_str, recurring = fields
            try:
                start = dateutil.parser.parse(start_str)
                end = dateutil.parser.parse(end_str)
            except (ValueError, OverflowError):
                print(f"Skipping malformed event line: {line.strip()!r}")
                continue
            loaded.append({
                "name": name,
                "start": start,
                "end": end,
                "recurring": recurring.lower() == "true",
            })
            print(f"Loaded event: {name} ({start} - {end})")
    return loaded


# In-memory event store, seeded from disk at import time.
events = _load_events()


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the DistilBERT model and tokenizer on first use and cache them.

    Loading is deferred because nothing on the request path calls
    ``generate_response``; importing this module no longer loads the model.
    """
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased"
    )
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    return model, tokenizer


def generate_response(prompt):
    """Generate a response using the DistilBERT model.

    NOTE(review): the model is loaded as a sequence classifier, so the argmax
    below looks like a class index that is then decoded as a token id.  That is
    presumably not the intended behaviour; confirm before relying on the output.
    """
    import torch

    model, tokenizer = _load_model()
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = model(**inputs)[0]  # get the logits
    return tokenizer.decode(
        torch.argmax(output, dim=-1)[0], skip_special_tokens=True
    )


def list_events(start, end):
    """Summarise stored events whose start falls in the window ``[start, end)``.

    Args:
        start: Inclusive window start (datetime; naive values are taken as UTC).
        end: Exclusive window end (datetime; naive values are taken as UTC).

    Returns:
        A comma-separated summary string, or a fixed "no events" message.

    NOTE: events flagged as recurring are matched on their stored start time
    only; later occurrences are not expanded.
    """
    window_start, window_end = _as_utc(start), _as_utc(end)
    summaries = []
    for event in events:
        event_start = _as_utc(event["start"])
        event_end = _as_utc(event["end"])
        if window_start <= event_start < window_end:
            summaries.append(
                f"{event['name']} ({event_start.strftime('%I:%M %p')} - "
                f"{event_end.strftime('%I:%M %p')})"
            )
    if not summaries:
        return "There are no events presently."
    return ", ".join(summaries)


def create_event(summary, start, end, recurring=False):
    """Append a new event to the store, persist it, and return a confirmation."""
    events.append(
        {"name": summary, "start": start, "end": end, "recurring": recurring}
    )
    save_events()
    return (
        f"Event '{summary}' has been scheduled from "
        f"{start.strftime('%I:%M %p')} to {end.strftime('%I:%M %p')}."
    )


def save_events():
    """Rewrite the events file from the in-memory ``events`` list."""
    with open(EVENTS_FILE, "w") as f:
        for event in events:
            start_str = event["start"].strftime("%Y-%m-%d %H:%M:%S")
            end_str = event["end"].strftime("%Y-%m-%d %H:%M:%S")
            recurring_str = "True" if event["recurring"] else "False"
            f.write(f"{event['name']}|{start_str}|{end_str}|{recurring_str}\n")


def process_input(user_input):
    """Route a user message to event creation or listing and return the reply."""
    lowered = (user_input or "").lower()
    if any(keyword in lowered for keyword in ("schedule", "create")):
        summary, start, end, recurring = extract_event_details(user_input)
        if summary and start and end:
            return create_event(summary, start, end, recurring)
        return "I'm sorry, I couldn't understand the event details. Please try again."
    if any(keyword in lowered for keyword in ("list", "show")):
        start = datetime.datetime.now(pytz.utc).replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        # list_events uses a half-open window, so the end is the next midnight.
        end = start + datetime.timedelta(days=1)
        return list_events(start, end)
    return "I'm sorry, I didn't understand your request. Please try again."


def _parse_clock(text):
    """Parse ``H:MM am/pm`` (space optional) into a ``time``; None if invalid."""
    match = re.fullmatch(r"\s*(\d{1,2}):(\d{2})\s*([aApP])[mM]\s*", text)
    if match is None:
        return None
    hour, minute = int(match[1]), int(match[2])
    if not 1 <= hour <= 12 or minute > 59:
        return None
    hour %= 12  # 12 AM -> 0, 12 PM -> 12 after the shift below
    if match[3].lower() == "p":
        hour += 12
    return datetime.time(hour, minute)


def _index_of(names, word):
    """Return the position of *word* (case-insensitive) in *names*, or None."""
    try:
        return names.index(word.lower())
    except ValueError:
        return None


def _nth_weekday(year, month, weekday, ordinal):
    """Return the date of the *ordinal*-th *weekday* of a month (-1 means last)."""
    if ordinal == -1:
        last = datetime.date(year, month, calendar.monthrange(year, month)[1])
        return last - datetime.timedelta(days=(last.weekday() - weekday) % 7)
    first = datetime.date(year, month, 1)
    offset = (weekday - first.weekday()) % 7 + 7 * (ordinal - 1)
    return first + datetime.timedelta(days=offset)


def _months_to_try(today, period):
    """Yield ``(year, month)`` candidates for *period*, soonest first.

    *period* is either the literal word ``"month"`` (every month, starting with
    the current one) or an English month name (that month in successive years).
    Any other word yields nothing.
    """
    if period == "month":
        for step in range(24):
            total = today.month - 1 + step
            yield today.year + total // 12, total % 12 + 1
    elif period in _MONTHS:
        month = _MONTHS.index(period) + 1
        # Nine years covers the longest gap between two 29 Februaries.
        for year in range(today.year, today.year + 9):
            yield year, month


def _first_on_or_after(today, period, build):
    """Return the first ``build(year, month)`` date that is not before *today*.

    Months for which *build* raises ``ValueError`` (for example the 31st of a
    30-day month) are skipped.  Returns None when no candidate qualifies.
    """
    for year, month in _months_to_try(today, period):
        try:
            candidate = build(year, month)
        except ValueError:
            continue
        if candidate >= today:
            return candidate
    return None


def _resolve_date(kind, match, today):
    """Turn a matched date phrase into ``(date, recurring)``; None if not understood."""
    if kind == "tomorrow":
        return today + datetime.timedelta(days=1), False

    if kind in ("on_weekday", "every_weekday"):
        weekday = _index_of(_WEEKDAYS, match["weekday"])
        if weekday is None:
            return None
        # Today counts as a match when it is already the requested weekday.
        offset = (weekday - today.weekday()) % 7
        return today + datetime.timedelta(days=offset), kind == "every_weekday"

    if kind == "ordinal_weekday":
        weekday = _index_of(_WEEKDAYS, match["weekday"])
        if weekday is None:
            return None
        ordinal = _ORDINALS[match["ordinal"].lower()]

        def build(year, month):
            return _nth_weekday(year, month, weekday, ordinal)
    else:  # "day_of_month"
        day_number = int(match["day"])

        def build(year, month):
            return datetime.date(year, month, day_number)

    found = _first_on_or_after(today, match["period"].lower(), build)
    return None if found is None else (found, True)


def extract_event_details(user_input, today=None):
    """Extract summary, start, end and recurrence from a scheduling request.

    Recognised forms, each prefixed by
    ``schedule|create <summary> from <H:MM am/pm> to <H:MM am/pm>``:

    * ``tomorrow``
    * ``on <weekday>`` (next such day, today included)
    * ``every <weekday>`` (recurring)
    * ``on the <Nth> of every month|<month name>`` (recurring)
    * ``on the first|second|third|fourth|last <weekday> of every
      month|<month name>`` (recurring)

    If the end time is earlier than the start time the event is taken to run
    past midnight and the end is moved to the following day.  When none of the
    forms apply, ``dateutil``'s fuzzy parser is tried on the whole text and a
    zero-length event named ``"Event"`` is returned.

    Args:
        user_input: The raw request text.
        today: Reference date for relative phrases; defaults to the current
            UTC date so it agrees with the UTC day used when listing events.

    Returns:
        ``(summary, start, end, recurring)`` with UTC-aware datetimes, or
        ``(None, None, None, False)`` when nothing could be understood.
    """
    if today is None:
        today = datetime.datetime.now(pytz.utc).date()
    text = user_input or ""

    for kind, pattern in _DATE_PATTERNS:
        match = pattern.search(text)
        if match is None:
            continue
        start_time = _parse_clock(match["start"])
        end_time = _parse_clock(match["end"])
        resolved = _resolve_date(kind, match, today)
        if start_time is None or end_time is None or resolved is None:
            continue
        day, recurring = resolved
        start = datetime.datetime.combine(day, start_time)
        end = datetime.datetime.combine(day, end_time)
        if end < start:
            end += datetime.timedelta(days=1)
        return (
            match["summary"].strip(),
            pytz.utc.localize(start),
            pytz.utc.localize(end),
            recurring,
        )

    # If the input doesn't match any pattern, try to parse it using dateutil.
    try:
        moment = _as_utc(dateutil.parser.parse(text, fuzzy=True))
    except (ValueError, OverflowError):
        return None, None, None, False
    return "Event", moment, moment, False


# Gradio interface
def chat(user_input):
    """Gradio callback: answer one user message."""
    return process_input(user_input)


def build_interface():
    """Create the Gradio interface (gradio is imported only when needed)."""
    import gradio as gr

    return gr.Interface(
        chat, inputs="text", outputs="text", title="AI Scheduling Assistant"
    )


if __name__ == "__main__":
    iface = build_interface()
    iface.launch()