File size: 8,059 Bytes
994d7ce
30645dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9150b81
30645dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
994d7ce
30645dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import datetime
import re
import os
import pytz
import dateutil.parser

# Load the DistilBERT model and tokenizer once, when this module is imported.
# NOTE(review): "distilbert-base-uncased" looks like a base checkpoint rather
# than one fine-tuned for classification, so the sequence-classification logits
# are presumably untrained -- confirm before relying on generate_response().
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# In-memory list of scheduled events.  Each entry is a dict with the keys
# "name", "start", "end" and "recurring".
events = []

# Restore previously saved events.  Each line of events.txt has the form
# name|start|end|recurring.
if os.path.isfile("events.txt"):
    with open("events.txt", "r") as f:
        for line in f:
            # Split from the right: the last three fields never contain "|",
            # so an event name that does contain one is kept intact instead of
            # the whole line being silently dropped (and lost on the next save).
            event_data = line.strip().rsplit("|", 3)
            if len(event_data) != 4:
                continue
            name, start_str, end_str, recurring = event_data
            try:
                start = dateutil.parser.parse(start_str)
                end = dateutil.parser.parse(end_str)
            except (ValueError, OverflowError):
                # One corrupt timestamp must not prevent the app from starting.
                print(f"Skipping malformed event line: {line.strip()!r}")
                continue
            is_recurring = (recurring.lower() == "true")
            events.append({"name": name, "start": start, "end": end, "recurring": is_recurring})
            print(f"Loaded event: {name} ({start} - {end})")

def generate_response(prompt):
    """
    Run the sequence-classification model on ``prompt`` and decode the result.

    The prompt is tokenized, passed through the module-level ``model``, and the
    index of the highest-scoring logit for the first input is handed to
    ``tokenizer.decode``.

    NOTE(review): the decoded value is a class index interpreted as a token id,
    not generated text -- confirm this is the intended behavior.

    Args:
        prompt: Text to feed to the model.

    Returns:
        The string produced by ``tokenizer.decode`` with special tokens skipped.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Inference only: disabling gradient tracking avoids building an autograd
    # graph on every call and does not change the returned value.
    with torch.no_grad():
        output = model(**inputs)[0]  # get the logits
    return tokenizer.decode(torch.argmax(output, dim=-1)[0], skip_special_tokens=True)

def list_events(start, end):
    """
    Summarize the stored events whose start falls in ``[start, end)``.

    Event datetimes without a timezone are localized to UTC before they are
    compared and formatted.

    Args:
        start: Inclusive lower bound for an event's start time.
        end: Exclusive upper bound for an event's start time.

    Returns:
        A comma-separated string of "name (start - end)" entries, or a fixed
        message when no event matches.
    """
    def as_utc(moment):
        # Naive datetimes are localized to UTC; aware ones pass through.
        return pytz.utc.localize(moment) if moment.tzinfo is None else moment

    summaries = []
    for entry in events:
        begins = as_utc(entry["start"])
        finishes = as_utc(entry["end"])
        if not (start <= begins < end):
            continue
        summaries.append(f"{entry['name']} ({begins.strftime('%I:%M %p')} - {finishes.strftime('%I:%M %p')})")

    return ", ".join(summaries) if summaries else "There are no events presently."

def create_event(summary, start, end, recurring=False):
    """
    Register a new event and persist the full event list.

    Args:
        summary: Name of the event.
        start: Datetime at which the event begins.
        end: Datetime at which the event ends.
        recurring: Whether the event repeats.

    Returns:
        A confirmation message with the start and end clock times.
    """
    events.append(dict(name=summary, start=start, end=end, recurring=recurring))
    save_events()
    begins = start.strftime('%I:%M %p')
    finishes = end.strftime('%I:%M %p')
    return f"Event '{summary}' has been scheduled from {begins} to {finishes}."

def save_events():
    """
    Overwrite events.txt with one ``name|start|end|recurring`` line per event.

    Start and end are written as ``YYYY-MM-DD HH:MM:SS``; the recurring flag is
    written as ``True`` or ``False``.
    """
    stamp = "%Y-%m-%d %H:%M:%S"
    with open("events.txt", "w") as out:
        for entry in events:
            begins = entry["start"].strftime(stamp)
            finishes = entry["end"].strftime(stamp)
            flag = str(bool(entry["recurring"]))
            out.write(f"{entry['name']}|{begins}|{finishes}|{flag}\n")

def process_input(user_input):
    """
    Route a chat message to event creation or event listing.

    Messages containing "schedule" or "create" are treated as creation
    requests.  Messages containing "list" or "show" return today's (UTC)
    events.  Because "schedule" is also a noun ("show my schedule"), a message
    that contains both kinds of keyword but no parseable event details is
    handled as a listing request rather than rejected.

    Args:
        user_input: Raw text typed by the user.

    Returns:
        The response string to show to the user.
    """
    lowered = user_input.lower()
    wants_creation = any(keyword in lowered for keyword in ("schedule", "create"))
    wants_listing = any(keyword in lowered for keyword in ("list", "show"))

    if wants_creation:
        summary, start, end, recurring = extract_event_details(user_input)
        if summary and start and end:
            return create_event(summary, start, end, recurring)
        if not wants_listing:
            return "I'm sorry, I couldn't understand the event details. Please try again."

    if wants_listing:
        start = datetime.datetime.now(pytz.utc).replace(hour=0, minute=0, second=0, microsecond=0)
        # list_events() uses a half-open [start, end) window, so the bound is
        # the next midnight; stopping short of it would drop events that begin
        # in the final second of the day.
        end = start + datetime.timedelta(days=1)
        return list_events(start, end)

    return "I'm sorry, I didn't understand your request. Please try again."

# English names keep parsing independent of the process locale.  Positions
# match date.weekday() (Monday == 0) and month numbers minus one.
_WEEKDAYS = ("monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday")
_MONTHS = ("january", "february", "march", "april", "may", "june", "july",
           "august", "september", "october", "november", "december")
# Position named by an ordinal word; -1 stands for "last".
_ORDINALS = {"first": 1, "second": 2, "third": 3, "fourth": 4, "last": -1}

# Standard-library UTC; compares and formats interchangeably with pytz.utc.
_UTC = datetime.timezone.utc

_CLOCK = r"\d+:\d+\s*[aApP][mM]"
_EVENT_PREFIX = (
    r"(?:schedule|create)(?P<summary>.*?)"
    r"from\s*(?P<start>" + _CLOCK + r")\s*to\s*(?P<end>" + _CLOCK + r")\s*"
)
# Ordered from most to least specific: the "on the ..." forms must be tried
# before the bare "on <word>" form, which would otherwise capture "the".
_EVENT_PATTERNS = tuple(
    re.compile(_EVENT_PREFIX + tail, re.IGNORECASE)
    for tail in (
        r"on\s*the\s*(?P<ordinal>last|first|second|third|fourth)\s*(?P<weekday>\w+)\s*of\s*every\s*(?P<month>\w+)",
        r"on\s*the\s*(?P<monthday>\w+)\s*of\s*every\s*(?P<month>\w+)",
        r"(?P<keyword>on|every)\s*(?P<weekday>\w+)",
        r"(?P<keyword>tomorrow)",
    )
)


def _parse_clock(text):
    """Convert a 12-hour clock string ("1:30 pm", "1:30PM") to a time, or None."""
    match = re.fullmatch(r"(\d+):(\d+)\s*([ap])m", text.strip(), re.IGNORECASE)
    if match is None:
        return None
    hour, minute = int(match.group(1)), int(match.group(2))
    if not (1 <= hour <= 12 and 0 <= minute <= 59):
        return None
    # 12 am is midnight and 12 pm is noon, hence the modulo before the offset.
    hour = hour % 12 + (12 if match.group(3).lower() == "p" else 0)
    return datetime.time(hour, minute)


def _days_in_month(year, month):
    """Return the number of days in the given month."""
    first_of_next = datetime.date(year + month // 12, month % 12 + 1, 1)
    return (first_of_next - datetime.timedelta(days=1)).day


def _nth_weekday(year, month, weekday, nth):
    """Return the date of the nth (1-4, or -1 for last) given weekday in a month."""
    if nth == -1:
        last = datetime.date(year, month, _days_in_month(year, month))
        return last - datetime.timedelta(days=(last.weekday() - weekday) % 7)
    first = datetime.date(year, month, 1)
    offset = (weekday - first.weekday()) % 7
    return first + datetime.timedelta(days=offset + 7 * (nth - 1))


def _day_in_month(year, month, word):
    """Return the date for a day-of-month word ("15", "15th", "first", "last"), or None."""
    length = _days_in_month(year, month)
    if word in _ORDINALS:
        number = length if _ORDINALS[word] == -1 else _ORDINALS[word]
    else:
        match = re.fullmatch(r"(\d{1,2})(?:st|nd|rd|th)?", word)
        if match is None:
            return None
        number = int(match.group(1))
    if not 1 <= number <= length:
        return None
    return datetime.date(year, month, number)


def _next_monthly_date(today, month_word, build):
    """Return the first date >= today that build(year, month) yields, or None.

    ``month_word`` is either the literal "month" (any month qualifies) or an
    English month name (only that month qualifies).
    """
    if month_word == "month":
        wanted = None
    elif month_word in _MONTHS:
        wanted = _MONTHS.index(month_word) + 1
    else:
        return None
    year, month = today.year, today.month
    # Nine years of months is enough to reach the next 29 February.
    for _ in range(12 * 9):
        if wanted is None or month == wanted:
            candidate = build(year, month)
            if candidate is not None and candidate >= today:
                return candidate
        year, month = year + month // 12, month % 12 + 1
    return None


def _resolve_date(fields, today):
    """Return (date, recurring) for the matched date phrase, or None if invalid."""
    if "month" in fields:
        if "ordinal" in fields:
            weekday_word = fields["weekday"].lower()
            if weekday_word not in _WEEKDAYS:
                return None
            weekday = _WEEKDAYS.index(weekday_word)
            nth = _ORDINALS[fields["ordinal"].lower()]

            def build(year, month):
                return _nth_weekday(year, month, weekday, nth)
        else:
            day_word = fields["monthday"].lower()

            def build(year, month):
                return _day_in_month(year, month, day_word)

        day = _next_monthly_date(today, fields["month"].lower(), build)
        return None if day is None else (day, True)

    keyword = fields["keyword"].lower()
    if keyword == "tomorrow":
        return today + datetime.timedelta(days=1), False

    weekday_word = fields["weekday"].lower()
    if weekday_word not in _WEEKDAYS:
        return None
    # Today counts when it already is the requested weekday.
    days_ahead = (_WEEKDAYS.index(weekday_word) - today.weekday()) % 7
    return today + datetime.timedelta(days=days_ahead), keyword == "every"


def extract_event_details(user_input):
    """
    Extract the event summary, start time, end time, and recurrence from the user input.

    Recognized phrasings (case-insensitive), each introduced by
    "schedule"/"create" <summary> "from" <h:mm am/pm> "to" <h:mm am/pm>:

    * ``... tomorrow``
    * ``... on <weekday>``
    * ``... every <weekday>`` (recurring)
    * ``... on the <day> of every <month name | "month">`` (recurring)
    * ``... on the <first|second|third|fourth|last> <weekday> of every
      <month name | "month">`` (recurring)

    Weekday and monthly phrases resolve to the next matching date, today
    included.  An end time earlier than the start time is taken to fall on
    the following day.  When no phrasing applies, the whole input is handed to
    ``dateutil`` fuzzy parsing and a zero-length event named "Event" is
    returned.

    Args:
        user_input: Raw text typed by the user.

    Returns:
        A ``(summary, start, end, recurring)`` tuple with timezone-aware UTC
        datetimes, or ``(None, None, None, False)`` when nothing could be
        parsed.
    """
    today = datetime.date.today()
    for pattern in _EVENT_PATTERNS:
        match = pattern.search(user_input)
        if match is None:
            continue
        fields = match.groupdict()
        start_time = _parse_clock(fields["start"])
        end_time = _parse_clock(fields["end"])
        if start_time is None or end_time is None:
            continue
        resolved = _resolve_date(fields, today)
        if resolved is None:
            continue
        day, recurring = resolved
        start = datetime.datetime.combine(day, start_time, tzinfo=_UTC)
        end = datetime.datetime.combine(day, end_time, tzinfo=_UTC)
        if end < start:
            # e.g. "from 11:00 pm to 1:00 am" ends on the next calendar day.
            end += datetime.timedelta(days=1)
        return fields["summary"].strip(), start, end, recurring

    # If the input doesn't match any pattern, try to parse it using dateutil
    try:
        parsed = dateutil.parser.parse(user_input, fuzzy=True)
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=_UTC)
        else:
            parsed = parsed.astimezone(_UTC)
    except (ValueError, OverflowError):
        return None, None, None, False
    return "Event", parsed, parsed, False

# Gradio interface
def chat(user_input):
    """Gradio callback: return the assistant's reply to ``user_input``."""
    return process_input(user_input)

# Build the single text-in / text-out interface around chat() and start serving it.
iface = gr.Interface(
    fn=chat,
    inputs="text",
    outputs="text",
    title="AI Scheduling Assistant",
)
iface.launch()