Tijmen2 committed on
Commit
fec3834
·
verified ·
1 Parent(s): 9780084

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -40
app.py CHANGED
@@ -1,17 +1,20 @@
1
  import spaces
2
  import gradio as gr
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
4
- import torch
5
  import random
6
 
7
- model_name = "AstroMLab/AstroSage-8B"
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- streamer = TextStreamer(tokenizer)
10
- # Load the model with 8-bit quantization using bitsandbytes
11
- model = AutoModelForCausalLM.from_pretrained(
12
- model_name,
13
- torch_dtype=torch.bfloat16,
14
- load_in_8bit=True,
 
 
 
15
  )
16
 
17
  # Placeholder responses for when context is empty
@@ -30,43 +33,42 @@ def user(user_message, history):
30
 
31
  @spaces.GPU(duration=20)
32
  def bot(history):
33
- """Generate the chatbot response."""
34
-
35
  if not history:
36
  history = []
 
 
 
 
 
 
 
 
37
 
38
- # Prepare input prompt for the model
39
- system_prompt = (
40
- "You are AstroSage, an intelligent AI assistant specializing in astronomy, astrophysics, and cosmology. "
41
- "Provide accurate, scientific information while making complex concepts accessible. "
42
- "You're enthusiastic about space exploration and maintain a sense of wonder about the cosmos."
43
- )
44
 
45
- # Construct the chat history as a single input string
46
- prompt = system_prompt + "\n\n"
47
- for message in history:
48
- if message["role"] == "user":
49
- prompt += f"User: {message['content']}\n"
50
- else:
51
- prompt += f"AstroSage: {message['content']}\n"
52
- prompt += "AstroSage: "
53
-
54
- # Generate response
55
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
56
- outputs = model.generate(
57
- **inputs,
58
- max_new_tokens=512,
59
  temperature=0.7,
60
  top_p=0.95,
61
- do_sample=True,
62
- streamer=streamer
63
  )
64
-
65
- # Decode the generated output and update history
66
- response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
67
- response_text = response_text[len(prompt):].strip()
68
- history.append({"role": "assistant", "content": response_text})
69
- yield history
70
 
71
  def initial_greeting():
72
  """Return properly formatted initial greeting."""
 
1
  import spaces
2
  import gradio as gr
3
+ from llama_cpp import Llama
4
+ from huggingface_hub import hf_hub_download
5
  import random
6
 
7
+ model_path = hf_hub_download(
8
+ repo_id="AstroMLab/AstroSage-8B-GGUF",
9
+ filename="AstroSage-8B-Q8_0.gguf"
10
+ )
11
+
12
+ llm = Llama(
13
+ model_path=model_path,
14
+ n_ctx=2048,
15
+ chat_format="llama-3",
16
+ n_gpu_layers=-1, # ensure all layers are on GPU
17
+ flash_attn=True,
18
  )
19
 
20
  # Placeholder responses for when context is empty
 
33
 
34
  @spaces.GPU(duration=20)
35
  def bot(history):
36
+ """Yield the chatbot response for streaming."""
37
+
38
  if not history:
39
  history = []
40
+
41
+ # Prepare the messages for the model
42
+ messages = [
43
+ {
44
+ "role": "system",
45
+ "content": "You are AstroSage, an intelligent AI assistant specializing in astronomy, astrophysics, and cosmology. Provide accurate, scientific information while making complex concepts accessible. You're enthusiastic about space exploration and maintain a sense of wonder about the cosmos."
46
+ }
47
+ ]
48
 
49
+ # Add chat history
50
+ for message in history[:-1]: # Exclude the last message which we just added
51
+ messages.append({"role": message["role"], "content": message["content"]})
 
 
 
52
 
53
+ # Add the current user message
54
+ messages.append({"role": "user", "content": history[-1]["content"]})
55
+
56
+ # Start generating the response
57
+ history.append({"role": "assistant", "content": ""})
58
+
59
+ # Stream the response
60
+ response = llm.create_chat_completion(
61
+ messages=messages,
62
+ max_tokens=512,
 
 
 
 
63
  temperature=0.7,
64
  top_p=0.95,
65
+ stream=True,
 
66
  )
67
+
68
+ for chunk in response:
69
+ if chunk and "content" in chunk["choices"][0]["delta"]:
70
+ history[-1]["content"] += chunk["choices"][0]["delta"]["content"]
71
+ yield history
 
72
 
73
  def initial_greeting():
74
  """Return properly formatted initial greeting."""