from threading import Thread
import gradio as gr
import random
import torch
import spaces
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

# Constants for the model and configuration
MODEL_ID = "AstroMLab/AstroSage-8B"
WINDOW_SIZE = 2048  # token budget shared by the prompt and the response
DEVICE = "cuda"

# Load model configuration, tokenizer, and model
config = AutoConfig.from_pretrained(pretrained_model_name_or_path=MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    config=config,
    device_map="auto",
    use_safetensors=True,
    trust_remote_code=True,
    # The bare load_in_4bit flag is deprecated in recent transformers;
    # 4-bit bitsandbytes quantization is configured explicitly instead.
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    torch_dtype=torch.bfloat16,
)

# Placeholder responses for when context is empty
GREETING_MESSAGES = [
"Greetings! I am AstroSage, your guide to the cosmos. What would you like to explore today?",
"Welcome to our cosmic journey! I am AstroSage. How may I assist you in understanding the universe?",
"AstroSage here. Ready to explore the mysteries of space and time. How may I be of assistance?",
"The universe awaits! I'm AstroSage. What astronomical wonders shall we discuss?",
]
def format_message(role: str, content: str) -> str:
"""Format a single message according to Llama-3 chat template."""
return f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
def generate_text(prompt: str, history: list, max_new_tokens=512, temperature=0.7, top_p=0.95):
"""
Generate a response using the transformer model with proper Llama-3 chat formatting.
"""
# Start with begin_of_text token
formatted_messages = ["<|begin_of_text|>"]
# Add formatted history
for msg in history:
formatted_message = format_message(msg['role'], msg['content'])
formatted_messages.append(formatted_message)
# Add the current prompt
formatted_message = format_message('user', prompt)
formatted_messages.append(formatted_message)
# Add the start of assistant's response
formatted_messages.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    # Combine all messages; the Llama-3 template places turns back to back,
    # so join with the empty string rather than a newline
    prompt_with_history = "".join(formatted_messages)
    # Encode the prompt and cap generation so prompt + response fit the window
    inputs = tokenizer([prompt_with_history], return_tensors="pt", truncation=True).to(DEVICE)
    input_length = inputs["input_ids"].shape[-1]
    max_new_tokens = max(1, min(max_new_tokens, WINDOW_SIZE - input_length))
# Prepare text streamer for live updates
streamer = TextIteratorStreamer(
tokenizer=tokenizer,
timeout=10.0,
skip_prompt=True,
skip_special_tokens=True
)
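    # skip_prompt=True keeps the echoed prompt out of the stream, and
    # skip_special_tokens strips markers such as <|eot_id|> from the reply.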
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,  # Llama-3 defines no pad token
    )
# Generate the response in a separate thread for streaming
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
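    # model.generate blocks until it finishes, so it runs in a worker thread
    # while the loop below consumes the streamer and yields partial text.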
# Collect and return the response
response = ""
for new_text in streamer:
response += new_text
yield response
def user(user_message, history):
"""
Add the user's message to the history.
"""
if history is None:
history = []
return "", history + [{"role": "user", "content": user_message}]
@spaces.GPU
def bot(history):
"""
Generate the bot's response based on the history.
"""
    if not history:
        # Nothing to respond to yet: seed the chat with a greeting instead
        history = [{"role": "assistant", "content": random.choice(GREETING_MESSAGES)}]
        yield history
        return
    # history[-1] is the current user message; pass the earlier turns as
    # context so generate_text does not re-append a duplicate of the prompt.
    response_generator = generate_text(history[-1]["content"], history[:-1])
    history.append({"role": "assistant", "content": ""})
# Stream the response back
for partial_response in response_generator:
history[-1]["content"] = partial_response
yield history

def initial_greeting():
    """
    Seed the conversation with a system prompt and a random greeting.
    The system message stays in the Chatbot state, so generate_text folds
    it into the Llama-3 prompt on every turn.
    """
return [
{"role": "system","content": "You are AstroSage, an intelligent AI assistant specializing in astronomy, astrophysics, and cosmology. Provide accurate, scientific information while making complex concepts accessible. You're enthusiastic about space exploration and maintain a sense of wonder about the cosmos. Start by introducing yourself."},
{"role": "assistant", "content": random.choice(GREETING_MESSAGES)}
]
# Custom CSS for a space theme
custom_css = """
#component-0 {
background-color: #1a1a2e;
border-radius: 15px;
padding: 20px;
}
.dark {
background-color: #0f0f1a;
}
.contain {
max-width: 1200px !important;
}
"""
# Create the Gradio interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate")) as demo:
gr.Markdown(
"""
# 🌌 AstroSage: Your Cosmic AI Companion
Welcome to AstroSage, an advanced AI assistant specializing in astronomy, astrophysics, and cosmology.
Powered by the AstroSage-Llama-3.1-8B model, I'm here to help you explore the wonders of the universe!
### What Can I Help You With?
- πŸͺ Explanations of astronomical phenomena
- πŸš€ Space exploration and missions
- ⭐ Stars, galaxies, and cosmology
- 🌍 Planetary science and exoplanets
- πŸ“Š Astrophysics concepts and theories
- πŸ”­ Astronomical instruments and observations
Just type your question below and let's embark on a cosmic journey together!
"""
)
chatbot = gr.Chatbot(
label="Chat with AstroSage",
bubble_full_width=False,
show_label=True,
height=450,
type="messages"
)
with gr.Row():
msg = gr.Textbox(
label="Type your message here",
placeholder="Ask me anything about space and astronomy...",
scale=9
)
clear = gr.Button("Clear Chat", scale=1)
# Example questions for quick start
gr.Examples(
examples=[
"What is a black hole and how does it form?",
"Can you explain the life cycle of a star?",
"What are exoplanets and how do we detect them?",
"Tell me about the James Webb Space Telescope.",
"What is dark matter and why is it important?"
],
inputs=msg,
label="Example Questions"
)
# Set up the message chain with streaming
msg.submit(
user,
[msg, chatbot],
[msg, chatbot],
queue=False
).then(
bot,
chatbot,
chatbot
)
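    # user() echoes the message immediately (queue=False); .then() chains
    # bot(), whose yielded histories stream into the chatbot as they arrive.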
    # Clear button: re-seed the chat so the system prompt survives a clear
    clear.click(initial_greeting, None, chatbot, queue=False)
# Initial greeting
demo.load(initial_greeting, None, chatbot, queue=False)
# Launch the app
if __name__ == "__main__":
demo.launch()