Spaces:
Sleeping
Sleeping
import gradio as gr | |
import requests | |
import re | |
from datetime import datetime, timedelta | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import os | |
# Load the fine-tuned weather model and its tokenizer.
# NOTE: both calls run at import time, so importing this module triggers the load.
model = AutoModelForCausalLM.from_pretrained("natechenette/weather-model-merged")
tokenizer = AutoTokenizer.from_pretrained("natechenette/weather-model-merged")

# Set up the optional OpenAI client (reads the OPENAI_API_KEY environment variable).
# OPENAI_AVAILABLE tells the rest of the module whether `client` can be used.
try:
    from openai import OpenAI

    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    OPENAI_AVAILABLE = True
except Exception:
    # Covers a missing `openai` package and any failure while building the
    # client. `Exception` (not a bare `except`) lets KeyboardInterrupt and
    # SystemExit propagate. `client` is still bound so the name always exists.
    client = None
    OPENAI_AVAILABLE = False
def extract_location(text):
    """Extract a place name from a free-text weather query.

    Candidates are tried in decreasing order of reliability:

    1. a capitalised phrase right after "in", "for" or "at"
       (e.g. "weather in New York" -> "New York");
    2. a single word after one of those prepositions, for lower-case input
       (e.g. "weather in tokyo" -> "tokyo");
    3. the first capitalised phrase anywhere in the text;
    4. the first word of two or more letters.

    Time/weather filler words ("today", "forecast", ...) are removed from
    every candidate, and candidates that end up empty, one character long,
    or equal to an article/preposition are skipped.

    Args:
        text: The user's message.

    Returns:
        The extracted location string, or None when nothing usable is found.
    """
    if not text:
        return None

    filler = re.compile(
        r'\b(?:this|last|next|current|today|tomorrow|week|month|year|'
        r'forecast|weather|data|information)\b',
        re.IGNORECASE,
    )
    non_locations = {"the", "a", "an", "in", "for", "at"}
    prepositions = ("in", "for", "at")
    # One or more capitalised words. Matched case-sensitively so the phrase
    # stops at the first lower-case word ("Berlin in the last 7 days").
    place_name = r'[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*'

    def tidy(candidate):
        # Drop filler words and collapse the whitespace they leave behind.
        candidate = " ".join(filler.sub(" ", candidate).split())
        if len(candidate) > 1 and candidate.lower() not in non_locations:
            return candidate
        return None

    searches = []
    # `(?i:...)` makes only the preposition case-insensitive; `\b` keeps "in"
    # from matching inside words such as "rain".
    for prep in prepositions:
        searches.append((rf'\b(?i:{prep})\s+({place_name})', 0))
    for prep in prepositions:
        searches.append((rf'\b{prep}\s+([a-z]+)', re.IGNORECASE))
    searches.append((rf'\b({place_name})', 0))
    searches.append((r'\b([a-z]{2,})', re.IGNORECASE))

    for pattern, flags in searches:
        # Try every occurrence, not just the first, so one rejected candidate
        # does not hide a later valid one.
        for match in re.finditer(pattern, text, flags):
            location = tidy(match.group(1))
            if location:
                return location
    return None
def extract_date_info(text):
    """Extract date information from a user query.

    Absolute dates are recognised in these forms (case-insensitive):
    ``MM/DD/YYYY``, ``YYYY-MM-DD``, ``Month DD, YYYY`` and ``DD Month YYYY``.
    Month names may be written in full or abbreviated ("Jun", "Sept", "Jun."),
    and the day may carry an ordinal suffix ("15th").

    Relative expressions are checked afterwards: "last 7 days" / "past week" /
    "last week", then "yesterday", then "today".

    Args:
        text: The user's message.

    Returns:
        A ``datetime`` for an absolute date, "yesterday" or "today";
        the string ``'last_week'`` for a past-week expression;
        ``None`` when no date information is found.
    """
    if not text:
        return None
    text_lower = text.lower()

    # Full month names plus their three-letter abbreviations (and "sept").
    full_names = (
        'january', 'february', 'march', 'april', 'may', 'june',
        'july', 'august', 'september', 'october', 'november', 'december',
    )
    month_numbers = {}
    for number, name in enumerate(full_names, start=1):
        month_numbers[name] = number
        month_numbers[name[:3]] = number
    month_numbers['sept'] = 9

    ordinal = r'(?:st|nd|rd|th)?'
    # Each pattern is paired with the meaning of its three capture groups, so
    # the conversion below does not have to inspect the regex source text.
    date_patterns = (
        (r'(\d{1,2})/(\d{1,2})/(\d{4})', ('month', 'day', 'year')),  # MM/DD/YYYY
        (r'(\d{4})-(\d{1,2})-(\d{1,2})', ('year', 'month', 'day')),  # YYYY-MM-DD
        (r'\b([a-z]+)\.?\s+(\d{1,2})' + ordinal + r',?\s+(\d{4})',
         ('month', 'day', 'year')),  # Month DD, YYYY
        (r'(\d{1,2})' + ordinal + r'\s+([a-z]+)\.?,?\s+(\d{4})',
         ('day', 'month', 'year')),  # DD Month YYYY
    )

    for pattern, order in date_patterns:
        for match in re.finditer(pattern, text_lower):
            parts = dict(zip(order, match.groups()))
            month = parts['month']
            month_number = int(month) if month.isdigit() else month_numbers.get(month)
            if month_number is None:
                # A word that is not a month name; keep looking.
                continue
            try:
                return datetime(int(parts['year']), month_number, int(parts['day']))
            except ValueError:
                # Impossible calendar date (e.g. 2/30/2024); keep looking.
                continue

    # Relative time expressions.
    if any(phrase in text_lower for phrase in ('last 7 days', 'past week', 'last week')):
        return 'last_week'
    if 'yesterday' in text_lower:
        return datetime.now() - timedelta(days=1)
    if 'today' in text_lower:
        return datetime.now()
    return None
def get_weather_data(location, query_type="current", specific_date=None):
    """Fetch weather data for a place from the Open-Meteo APIs.

    The place name is first resolved to coordinates with the geocoding API.
    The weather endpoint is then chosen as follows:

    * ``specific_date`` is a ``datetime``: archive API for that single day;
    * ``specific_date`` is any other truthy value (e.g. ``'last_week'``):
      archive API for the last 7 days;
    * ``query_type`` mentions "forecast", "tomorrow" or "next":
      forecast API, daily values;
    * ``query_type`` mentions "historical", "last week", "yesterday" or "was":
      archive API for the last 7 days;
    * otherwise: forecast API, current conditions.

    Args:
        location: Place name to geocode.
        query_type: Free-text hint describing the kind of query.
        specific_date: ``datetime``, a relative marker string, or ``None``.

    Returns:
        The parsed JSON response of the weather request on success, or a
        human-readable error string when the location is unknown or any
        exception occurs.
    """
    geocoding_url = "https://geocoding-api.open-meteo.com/v1/search"
    archive_url = "https://archive-api.open-meteo.com/v1/archive"
    forecast_url = "https://api.open-meteo.com/v1/forecast"
    daily_fields = "temperature_2m_max,temperature_2m_min,weather_code"
    timeout_seconds = 10  # never block the UI indefinitely on a stalled request

    def last_seven_days():
        # (start, end) ISO dates covering today and the seven days before it.
        now = datetime.now()
        return (now - timedelta(days=7)).strftime("%Y-%m-%d"), now.strftime("%Y-%m-%d")

    try:
        # Geocoding. Passing `params` lets requests URL-encode the place name
        # (spaces, '&', '#', non-ASCII characters).
        geo_response = requests.get(
            geocoding_url,
            params={"name": location, "count": 1},
            timeout=timeout_seconds,
        )
        geo_data = geo_response.json()
        if not geo_data.get('results'):
            return f"Location '{location}' not found. Please try a different city name."

        place = geo_data['results'][0]
        params = {"latitude": place['latitude'], "longitude": place['longitude']}

        # Decide which endpoint and parameters to use.
        if specific_date:
            if isinstance(specific_date, datetime):
                start_date = end_date = specific_date.strftime("%Y-%m-%d")
            else:
                # Relative time like "last week".
                start_date, end_date = last_seven_days()
            weather_url = archive_url
            params.update(start_date=start_date, end_date=end_date, daily=daily_fields)
        else:
            kind = query_type.lower()
            if "forecast" in kind or "tomorrow" in kind or "next" in kind:
                # Daily forecast.
                weather_url = forecast_url
                params.update(daily=daily_fields, timezone="auto")
            elif "historical" in kind or "last week" in kind or "yesterday" in kind or "was" in kind:
                # Historical data for the last 7 days.
                start_date, end_date = last_seven_days()
                weather_url = archive_url
                params.update(start_date=start_date, end_date=end_date, daily=daily_fields)
            else:
                # Current conditions.
                weather_url = forecast_url
                params.update(current="temperature_2m,weather_code", timezone="auto")

        weather_response = requests.get(weather_url, params=params, timeout=timeout_seconds)
        return weather_response.json()
    except Exception as e:
        return f"Error fetching weather data: {str(e)}"
def get_weather_description(code):
    """Convert a WMO weather interpretation code into a short description.

    Covers the WMO codes listed in the Open-Meteo documentation, including
    freezing drizzle/rain, snow grains, rain/snow showers and thunderstorms
    with hail.

    Args:
        code: Numeric weather code (may be ``None`` when the API has no value).

    Returns:
        A lower-case description, or "unknown weather conditions" for codes
        that are not in the table.
    """
    descriptions = {
        0: "clear sky",
        1: "mainly clear",
        2: "partly cloudy",
        3: "overcast",
        45: "foggy",
        48: "depositing rime fog",
        51: "light drizzle",
        53: "moderate drizzle",
        55: "dense drizzle",
        56: "light freezing drizzle",
        57: "dense freezing drizzle",
        61: "slight rain",
        63: "moderate rain",
        65: "heavy rain",
        66: "light freezing rain",
        67: "heavy freezing rain",
        71: "slight snow",
        73: "moderate snow",
        75: "heavy snow",
        77: "snow grains",
        80: "slight rain showers",
        81: "moderate rain showers",
        82: "violent rain showers",
        85: "slight snow showers",
        86: "heavy snow showers",
        95: "thunderstorm",
        96: "thunderstorm with slight hail",
        99: "thunderstorm with heavy hail",
    }
    return descriptions.get(code, "unknown weather conditions")
def format_weather_response(location, weather_data, query_type="current", specific_date=None):
    """Format weather data into a natural-language response.

    Args:
        location: Place name to mention in the response.
        weather_data: Either an error string (returned unchanged) or a
            response mapping containing a "current" or a "daily" section.
        query_type: Free-text hint; "forecast" or "tomorrow" selects the
            forecast wording for daily data, anything else the history wording.
        specific_date: ``datetime`` to report a single day, ``'last_week'``
            to use the past-week heading, or ``None``.

    Returns:
        The formatted response text. Days whose temperatures are missing
        (``None``) are reported as "data not available" instead of being
        printed as "None°C". Any unexpected error is turned into an apology
        message that includes the error text.
    """
    if isinstance(weather_data, str):
        return weather_data  # Return error message as is

    try:
        if "current" in weather_data:
            # Current weather
            current = weather_data["current"]
            temp = current["temperature_2m"]
            weather_desc = get_weather_description(current["weather_code"])
            return f"The current weather in {location} is {temp}°C with {weather_desc}."

        if "daily" in weather_data:
            daily = weather_data["daily"]
            required = ("time", "temperature_2m_max", "temperature_2m_min", "weather_code")
            if not all(key in daily for key in required):
                return f"I received weather data for {location}, but it's missing some information. Please try again."

            # One (date, max, min, code) tuple per day; zip stops at the
            # shortest series, so ragged responses cannot cause index errors.
            days = list(zip(*(daily[key] for key in required)))

            if isinstance(specific_date, datetime):
                # Specific date requested
                date_str = specific_date.strftime("%Y-%m-%d")
                for day, max_temp, min_temp, code in days:
                    if day != date_str:
                        continue
                    if max_temp is None or min_temp is None:
                        break  # the day is listed but has no values yet
                    weather_desc = get_weather_description(code)
                    formatted_date = specific_date.strftime("%A, %B %d, %Y")
                    return f"On {formatted_date}, the weather in {location} had a high of {max_temp}°C, low of {min_temp}°C with {weather_desc}."
                return f"I couldn't find weather data for {location} on {date_str}."

            kind = query_type.lower()
            if "forecast" in kind or "tomorrow" in kind:
                header = f"Here's the weather forecast for {location}:\n\n"
            elif specific_date == 'last_week':
                header = f"Here's the weather history for {location} over the past week:\n\n"
            else:
                header = f"Here's the weather history for {location}:\n\n"

            lines = [header]
            for day, max_temp, min_temp, code in days[:7]:
                try:
                    date = datetime.strptime(day, "%Y-%m-%d").strftime("%A, %B %d")
                except (ValueError, TypeError):
                    continue  # skip entries whose date cannot be parsed
                if max_temp is None or min_temp is None:
                    lines.append(f"{date}: data not available.\n")
                    continue
                weather_desc = get_weather_description(code)
                lines.append(f"{date}: High of {max_temp}°C, low of {min_temp}°C with {weather_desc}.\n")
            return "".join(lines)

        return f"I found weather data for {location}, but I'm not sure how to format it for your query."
    except Exception as e:
        return f"Sorry, I had trouble processing the weather data for {location}. Error: {str(e)}"
def get_chatgpt_response(user_input):
    """Ask the OpenAI chat model the same question, for side-by-side comparison.

    Args:
        user_input: The user's message, sent as the "user" turn.

    Returns:
        The text of the first completion choice, a fixed notice when the
        OpenAI client is unavailable, or an error message if the call fails.
    """
    if not OPENAI_AVAILABLE:
        return "ChatGPT comparison not available (API key not set)"

    conversation = [
        {"role": "system", "content": "You are a helpful weather assistant. Provide natural, conversational responses about weather."},
        {"role": "user", "content": user_input},
    ]
    try:
        # NOTE(review): the UI column heading in this file says "gpt-4o-mini"
        # while this call requests "gpt-3.5-turbo" — confirm which is intended.
        completion = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=conversation,
            max_tokens=200,
            temperature=0.7,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as error:
        return f"Error getting ChatGPT response: {str(error)}"
def chat_with_weather(message, history):
    """Answer a weather question and fetch a ChatGPT answer for comparison.

    Args:
        message: The user's question.
        history: Chat history; accepted for interface compatibility, unused.

    Returns:
        A ``(weather_response, chatgpt_response)`` tuple of strings. The tuple
        shape is returned on every path, including when no location can be
        identified, because the caller unpacks two values.
    """
    # Extract location from user message
    location = extract_location(message)
    if not location:
        return (
            "I couldn't identify a location in your message. Please try asking about a specific city or location.",
            get_chatgpt_response(message),
        )

    # Extract date information
    specific_date = extract_date_info(message)

    # Determine query type
    text = message.lower()
    query_type = "current"
    if specific_date:
        query_type = "historical"
    elif any(word in text for word in ["forecast", "tomorrow", "next week"]):
        query_type = "forecast"
    elif (
        any(word in text for word in ["historical", "last week", "yesterday"])
        # "was" must be a whole word, otherwise "Washington" reads as past tense.
        or re.search(r"\bwas\b", text)
    ):
        query_type = "historical"

    # Get weather data and format it in natural language
    weather_data = get_weather_data(location, query_type, specific_date)
    weather_response = format_weather_response(location, weather_data, query_type, specific_date)

    # Get ChatGPT response for comparison
    chatgpt_response = get_chatgpt_response(message)
    return weather_response, chatgpt_response
# Build the Gradio UI: two side-by-side answer boxes, a question box with a
# submit button, and a list of clickable example questions.
with gr.Blocks(title="Weather Model Comparison") as demo:
    gr.Markdown("# nateAI - Weather")
    gr.Markdown("Compare responses from nateAI to ChatGPT")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### nateAI")
            weather_output = gr.Textbox(label="nateAI Response", lines=5)
        with gr.Column():
            # NOTE(review): this heading says "gpt-4o-mini"; confirm it matches
            # the model actually requested by get_chatgpt_response.
            gr.Markdown("### gpt-4o-mini")
            chatgpt_output = gr.Textbox(label="ChatGPT Response", lines=5)

    # NOTE(review): assumes the question box and the button share this row —
    # confirm against the intended layout.
    with gr.Row():
        input_text = gr.Textbox(label="Ask about weather", placeholder="What's the weather like in Tokyo?")
        submit_btn = gr.Button("Get Weather", variant="primary")

    # Example questions the user can click to fill the input box.
    gr.Markdown("### Try these examples:")
    examples = [
        "What's the current weather in New York?",
        "Weather forecast for London",
        "What was the weather like in Tokyo last week?",
        "Weather in Paris on June 15, 2024",
        "What's the weather forecast for Sydney tomorrow?",
        "Historical weather for Berlin in the last 7 days",
    ]
    gr.Examples(examples=examples, inputs=input_text, label="Examples")

    def process_message(message):
        """Run one question through the chat pipeline with an empty history."""
        nate_answer, gpt_answer = chat_with_weather(message, [])
        return nate_answer, gpt_answer

    # The button click and pressing Enter in the text box trigger the same handler.
    for register in (submit_btn.click, input_text.submit):
        register(
            fn=process_message,
            inputs=[input_text],
            outputs=[weather_output, chatgpt_output],
        )

# Launch the app when run as a script.
if __name__ == "__main__":
    demo.launch()