Spaces:
Build error
Build error
File size: 8,059 Bytes
994d7ce 30645dd 9150b81 30645dd 994d7ce 30645dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import datetime
import re
import os
import pytz
import dateutil.parser
# Checkpoint name shared by the model and tokenizer loaded below.
_MODEL_NAME = "distilbert-base-uncased"
# DistilBERT with a sequence-classification head, plus its matching tokenizer.
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
# In-memory event store read and written by list_events, create_event and
# save_events. Each event is a dict with "name", "start", "end" and "recurring".
events = []
# Load events persisted by save_events() (one "name|start|end|recurring" line
# per event), if the file exists.
if os.path.isfile("events.txt"):
    with open("events.txt", "r") as f:
        for line in f:
            # Split from the right: save_events writes the name unescaped, so a
            # name may itself contain "|", but the three trailing fields
            # (formatted timestamps and "True"/"False") never do.
            event_data = line.strip().rsplit("|", 3)
            if len(event_data) != 4:
                continue  # blank or malformed line
            name, start_str, end_str, recurring = event_data
            try:
                start = dateutil.parser.parse(start_str)
                end = dateutil.parser.parse(end_str)
            except (ValueError, OverflowError):
                # One corrupt line should not stop the app from starting.
                print(f"Skipping unreadable event line: {line.strip()!r}")
                continue
            is_recurring = (recurring.lower() == "true")
            events.append({"name": name, "start": start, "end": end, "recurring": is_recurring})
            print(f"Loaded event: {name} ({start} - {end})")
def generate_response(prompt):
    """
    Run the module-level DistilBERT model on *prompt* and decode its output.

    NOTE(review): despite the name, this does not generate text. ``model`` is
    loaded with AutoModelForSequenceClassification, so ``model(**inputs)[0]``
    is a tensor of class logits; argmax picks a class index, which is then
    decoded as though it were a token id. It does not appear to be called
    anywhere in this file -- confirm before relying on it.

    :param prompt: input text.
    :return: the decoded string for the predicted index (special tokens skipped).
    """
    # Truncate to the model's maximum length instead of failing on long prompts.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs)[0]
    return tokenizer.decode(torch.argmax(logits, dim=-1)[0], skip_special_tokens=True)
def list_events(start, end, event_list=None):
    """
    List the events whose start time falls in the half-open window [start, end).

    Naive datetimes -- stored event times as well as the start/end bounds --
    are treated as UTC, so naive and timezone-aware values compare cleanly.

    :param start: inclusive lower bound (datetime).
    :param end: exclusive upper bound (datetime).
    :param event_list: events to search; defaults to the module-level ``events``.
    :return: comma-separated "name (HH:MM AM - HH:MM PM)" summaries in list
             order, or "There are no events presently." when none match.
    """
    if event_list is None:
        event_list = events

    def as_utc(moment):
        # pytz.utc.localize only accepts naive datetimes; aware ones pass through.
        return pytz.utc.localize(moment) if moment.tzinfo is None else moment

    start, end = as_utc(start), as_utc(end)
    event_summaries = []
    for event in event_list:
        event_start = as_utc(event["start"])
        event_end = as_utc(event["end"])
        if start <= event_start < end:
            event_summaries.append(
                f"{event['name']} ({event_start.strftime('%I:%M %p')} - {event_end.strftime('%I:%M %p')})"
            )
    if not event_summaries:
        return "There are no events presently."
    return ", ".join(event_summaries)
def create_event(summary, start, end, recurring=False):
    """
    Add an event to the in-memory store, persist the store, and confirm it.

    :param summary: event title, stored under the "name" key.
    :param start: start datetime (shown as e.g. "09:00 AM" in the reply).
    :param end: end datetime (shown the same way).
    :param recurring: stored as-is under the "recurring" key.
    :return: a confirmation message naming the event and its time range.
    """
    new_event = dict(name=summary, start=start, end=end, recurring=recurring)
    events.append(new_event)
    # Rewrite events.txt so the new event is reloaded on the next start-up.
    save_events()
    return f"Event '{summary}' has been scheduled from {start.strftime('%I:%M %p')} to {end.strftime('%I:%M %p')}."
def save_events():
    """
    Persist the module-level ``events`` list to events.txt.

    Each event becomes one "name|start|end|recurring" line. Times are written
    as "%Y-%m-%d %H:%M:%S" with no UTC offset, and recurring as "True"/"False".
    Timezone-aware times are converted to UTC first: the offset is not
    stored, and naive times are treated as UTC when listing (list_events).
    Writing another zone's wall-clock time would therefore shift the event.

    The lines are written to a temporary sibling file, which then replaces
    events.txt via os.replace. An interrupted write therefore cannot truncate
    the previously saved events.
    """
    def format_time(moment):
        if moment.tzinfo is not None:
            moment = moment.astimezone(pytz.utc)
        return moment.strftime("%Y-%m-%d %H:%M:%S")

    lines = []
    for event in events:
        start_str = format_time(event["start"])
        end_str = format_time(event["end"])
        recurring_str = "True" if event["recurring"] else "False"
        lines.append(f"{event['name']}|{start_str}|{end_str}|{recurring_str}\n")

    tmp_path = "events.txt.tmp"
    with open(tmp_path, "w") as f:
        f.writelines(lines)
    os.replace(tmp_path, "events.txt")
def process_input(user_input):
    """
    Route a chat message to the matching action and return the reply text.

    - Messages containing "schedule" or "create" are parsed with
      extract_event_details(). If the details are complete, the event is
      stored via create_event().
    - Messages containing "list" or "show" list today's (UTC) events via
      list_events().
    - Anything else gets a fallback apology.

    :param user_input: raw text typed by the user (None is treated as "").
    :return: the reply string shown in the chat UI.
    """
    text = user_input or ""
    lowered = text.lower()
    if any(keyword in lowered for keyword in ("schedule", "create")):
        summary, start, end, recurring = extract_event_details(text)
        if summary and start and end:
            return create_event(summary, start, end, recurring)
        return "I'm sorry, I couldn't understand the event details. Please try again."
    if any(keyword in lowered for keyword in ("list", "show")):
        # Half-open window [today 00:00, tomorrow 00:00) in UTC. list_events
        # checks start <= event_start < end, so the end bound is the next
        # midnight itself; subtracting a second would hide events starting
        # in the final second of the day.
        day_start = datetime.datetime.now(pytz.utc).replace(hour=0, minute=0, second=0, microsecond=0)
        day_end = day_start + datetime.timedelta(days=1)
        return list_events(day_start, day_end)
    return "I'm sorry, I didn't understand your request. Please try again."
# Clock time such as "10:00 am", "10:00am" or "10 am" (minutes optional).
_CLOCK = r"\d{1,2}(?::\d{2})?\s*[ap]m"
# Shared prefix: "<schedule|create> <summary> from <start> to <end>".
_EVENT_HEAD = (
    r"(?:schedule|create)(?P<summary>.*?)"
    r"from\s*(?P<start>" + _CLOCK + r")\s*to\s*(?P<end>" + _CLOCK + r")\s*"
)
# Weekday names indexed like datetime.date.weekday() (Monday == 0).
_WEEKDAYS = ("monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday")
# Whole weeks to add to a month's first matching weekday for each ordinal.
_ORDINAL_WEEKS = {"first": 0, "second": 1, "third": 2, "fourth": 3}


def _parse_clock(text):
    """Parse "10:00 am" / "10:00am" / "10 am" into a datetime.time; raises ValueError."""
    compact = "".join(text.split()).upper()
    fmt = "%I:%M%p" if ":" in compact else "%I%p"
    return datetime.datetime.strptime(compact, fmt).time()


def _weekday_index(name):
    """Return the Monday-based index of a weekday name, or None if it is not one."""
    name = name.lower()
    return _WEEKDAYS.index(name) if name in _WEEKDAYS else None


def _next_month(year, month):
    """Return (year, month) of the calendar month after the given one."""
    return (year + 1, 1) if month == 12 else (year, month + 1)


def _first_monthly_date(today, pick):
    """Return the earliest non-None pick(year, month) on or after today, scanning 13 months."""
    year, month = today.year, today.month
    for _ in range(13):
        candidate = pick(year, month)
        if candidate is not None and candidate >= today:
            return candidate
        year, month = _next_month(year, month)
    return None


def _resolve_nth_weekday(match, today):
    """Resolve "on the <first..fourth|last> <weekday> of every month" to (date, True)."""
    weekday = _weekday_index(match.group("weekday"))
    if weekday is None:
        return None
    ordinal = match.group("ordinal").lower()

    def pick(year, month):
        if ordinal == "last":
            last_day = datetime.date(*_next_month(year, month), 1) - datetime.timedelta(days=1)
            return last_day - datetime.timedelta(days=(last_day.weekday() - weekday) % 7)
        first_day = datetime.date(year, month, 1)
        offset = (weekday - first_day.weekday()) % 7 + 7 * _ORDINAL_WEEKS[ordinal]
        return first_day + datetime.timedelta(days=offset)

    day = _first_monthly_date(today, pick)
    return None if day is None else (day, True)


def _resolve_day_of_month(match, today):
    """Resolve "on the <N>[st|nd|rd|th] of every month" to (date, True), skipping months without day N."""
    day_number = int(match.group("day"))

    def pick(year, month):
        try:
            return datetime.date(year, month, day_number)
        except ValueError:  # e.g. the 31st in April, or day 0
            return None

    day = _first_monthly_date(today, pick)
    return None if day is None else (day, True)


def _weekday_resolver(recurring):
    """Build a resolver for "every <weekday>" / "on <weekday>": next occurrence, today included."""
    def resolve(match, today):
        weekday = _weekday_index(match.group("weekday"))
        if weekday is None:
            return None
        return today + datetime.timedelta(days=(weekday - today.weekday()) % 7), recurring
    return resolve


def _resolve_tomorrow(_match, today):
    """Resolve "tomorrow" to (today + 1 day, False)."""
    return today + datetime.timedelta(days=1), False


# Most specific phrasing first, so "on the first Monday ..." is not taken as "on <weekday>".
_EVENT_PATTERNS = [
    (re.compile(_EVENT_HEAD + tail, re.IGNORECASE), resolver)
    for tail, resolver in (
        (r"on\s*the\s*(?P<ordinal>first|second|third|fourth|last)\s*(?P<weekday>\w+)\s*of\s*every\s*month",
         _resolve_nth_weekday),
        (r"on\s*the\s*(?P<day>\d{1,2})(?:st|nd|rd|th)?\s*of\s*every\s*month", _resolve_day_of_month),
        (r"every\s*(?P<weekday>\w+)", _weekday_resolver(True)),
        (r"on\s*(?P<weekday>\w+)", _weekday_resolver(False)),
        (r"tomorrow", _resolve_tomorrow),
    )
]


def _as_utc(moment):
    """Tag a naive datetime as UTC, or convert an aware one to UTC."""
    if moment.tzinfo is None:
        return moment.replace(tzinfo=datetime.timezone.utc)
    return moment.astimezone(datetime.timezone.utc)


def extract_event_details(user_input):
    """
    Extract the event summary, start time, end time, and recurrence from the user input.

    Every recognised phrasing (case-insensitive) starts with
    "schedule"/"create" <summary> "from <time> to <time>", where <time> is
    e.g. "10:00 am", "10:00am" or "10 am". It then ends with one of:
      * "on the <first|second|third|fourth|last> <weekday> of every month"  (recurring)
      * "on the <N>[st|nd|rd|th] of every month"                            (recurring)
      * "every <weekday>"                                                   (recurring)
      * "on <weekday>"                                                      (one-off)
      * "tomorrow"                                                          (one-off)

    Weekday and monthly phrasings resolve to the next matching date on or
    after today. An end time earlier than the start time is placed on the
    following day. Clock times are tagged as UTC exactly as written; no
    local-time conversion is done.

    If no phrasing matches, dateutil's fuzzy parser is tried. A successful
    parse yields a zero-length event named "Event".

    :param user_input: the raw chat message.
    :return: (summary, start, end, recurring) with UTC-aware datetimes, or
             (None, None, None, False) when nothing could be parsed.
    """
    today = datetime.date.today()
    for pattern, resolve in _EVENT_PATTERNS:
        match = pattern.search(user_input)
        if match is None:
            continue
        try:
            start_time = _parse_clock(match.group("start"))
            end_time = _parse_clock(match.group("end"))
        except ValueError:  # e.g. "13:00 pm" matches the regex but is not a 12-hour time
            continue
        resolved = resolve(match, today)
        if resolved is None:  # e.g. "on the" / "every day" is not a weekday
            continue
        day, recurring = resolved
        start = datetime.datetime.combine(day, start_time)
        end = datetime.datetime.combine(day, end_time)
        if end < start:  # "from 11:00 pm to 1:00 am" ends the following day
            end += datetime.timedelta(days=1)
        return match.group("summary").strip(), _as_utc(start), _as_utc(end), recurring

    # No explicit phrasing matched: fall back to dateutil's fuzzy parser,
    # which returns a single datetime (or raises when none is found).
    try:
        parsed = dateutil.parser.parse(user_input, fuzzy=True)
    except (ValueError, OverflowError):
        return None, None, None, False
    parsed = _as_utc(parsed)
    return "Event", parsed, parsed, False
# Gradio front end: one text box in, the assistant's reply out.
def chat(user_input):
    """
    Gradio callback: hand the message to process_input and return its reply.
    """
    return process_input(user_input)


iface = gr.Interface(fn=chat, inputs="text", outputs="text", title="AI Scheduling Assistant")
iface.launch()