SandLogicTechnologies committed on
Commit
d816a8a
1 Parent(s): 3080342

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -37
app.py CHANGED
@@ -8,9 +8,8 @@ import torch
8
  import json
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
 
11
-
12
  DESCRIPTION = """\
13
- Shakti is a 2.5 billion parameter language model specifically optimized for resource-constrained environments such as edge devices, including smartphones, wearables, and IoT systems. With support for vernacular languages and domain-specific tasks, Shakti excels in industries such as healthcare, finance, and customer service
14
  For more details, please check [here](https://arxiv.org/pdf/2410.11331v1).
15
  """
16
 
@@ -20,17 +19,31 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048"))
20
 
21
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
 
23
- model_id = "SandLogicTechnologies/Shakti-2.5B"
24
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("SHAKTI"))
25
- model = AutoModelForCausalLM.from_pretrained(
26
- model_id,
27
- device_map="auto",
28
- torch_dtype=torch.bfloat16,
29
- token=os.getenv("SHAKTI")
30
-
31
- )
32
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
33
 
 
 
34
 
35
  @spaces.GPU(duration=90)
36
  def generate(
@@ -79,6 +92,28 @@ def generate(
79
  outputs.append(text)
80
  yield "".join(outputs)
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  chat_interface = gr.ChatInterface(
84
  fn=generate,
@@ -97,39 +132,28 @@ chat_interface = gr.ChatInterface(
97
  step=0.1,
98
  value=0.6,
99
  ),
100
- # gr.Slider(
101
- # label="Top-p (nucleus sampling)",
102
- # minimum=0.05,
103
- # maximum=1.0,
104
- # step=0.05,
105
- # value=0.9,
106
- # ),
107
- # gr.Slider(
108
- # label="Top-k",
109
- # minimum=1,
110
- # maximum=1000,
111
- # step=1,
112
- # value=50,
113
- # ),
114
- # gr.Slider(
115
- # label="Repetition penalty",
116
- # minimum=1.0,
117
- # maximum=2.0,
118
- # step=0.05,
119
- # value=1.2,
120
- # ),
121
  ],
122
  stop_btn=None,
123
- examples=[
124
- ["Tell me a story"], ["write a short poem which is hard to sing"], ['मुझे भारतीय इतिहास के बारे में बताएं']
125
- ],
126
  cache_examples=False,
127
  )
128
 
129
  with gr.Blocks(css="style.css", fill_height=True) as demo:
130
  gr.Markdown(DESCRIPTION)
131
  gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
 
 
 
 
 
 
 
 
 
 
 
 
132
  chat_interface.render()
133
 
134
  if __name__ == "__main__":
135
- demo.queue(max_size=20).launch()
 
8
  import json
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
 
 
11
  DESCRIPTION = """\
12
+ Shakti is a 2.5 billion parameter language model specifically optimized for resource-constrained environments such as edge devices, including smartphones, wearables, and IoT systems. With support for vernacular languages and domain-specific tasks, Shakti excels in industries such as healthcare, finance, and customer service.
13
  For more details, please check [here](https://arxiv.org/pdf/2410.11331v1).
14
  """
15
 
 
19
 
20
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
 
22
# Model configurations: dropdown display name -> Hugging Face Hub repository id.
# Every checkpoint lives under the same organisation, so the repository id is
# derived from the display name.
model_options = {
    display_name: f"SandLogicTechnologies/{display_name}"
    for display_name in ("Shakti-100M", "Shakti-250M", "Shakti-2.5B")
}

# Currently active tokenizer/model pair; both stay None until load_model() runs.
tokenizer = None
model = None
32
+
33
def load_model(selected_model: str):
    """Load the tokenizer and model for ``selected_model`` into the module globals.

    The Hub repository id is looked up in ``model_options`` and the access
    token is read from the ``SHAKTI`` environment variable. On success the
    module-level ``tokenizer`` and ``model`` are rebound together, with the
    model switched to eval mode.

    Args:
        selected_model: A key of ``model_options`` (for example
            ``"Shakti-2.5B"``).

    Raises:
        KeyError: If ``selected_model`` is not a key of ``model_options``.
    """
    global tokenizer, model
    model_id = model_options[selected_model]
    # Read the token once so both downloads are guaranteed to use the same value.
    hub_token = os.getenv("SHAKTI")

    # Load into locals first: if either download fails, the previously active
    # tokenizer/model pair is left untouched instead of ending up with a new
    # tokenizer paired with the old model.
    new_tokenizer = AutoTokenizer.from_pretrained(model_id, token=hub_token)
    new_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        token=hub_token,
    )
    new_model.eval()

    # Swap both globals in one statement so they always describe the same checkpoint.
    tokenizer, model = new_tokenizer, new_model
44
 
45
+ # Initial model load (default to 2.5B)
46
+ load_model("Shakti-2.5B")
47
 
48
  @spaces.GPU(duration=90)
49
  def generate(
 
92
  outputs.append(text)
93
  yield "".join(outputs)
94
 
95
def update_examples(selected_model):
    """Return the example prompts to show for ``selected_model``.

    Each prompt is wrapped in its own single-element list. ``"Shakti-100M"``
    and ``"Shakti-250M"`` have dedicated prompt sets; any other value gets the
    Shakti-2.5B prompts. A fresh list is built on every call.
    """
    if selected_model == "Shakti-100M":
        prompts = (
            "Tell me a story",
            "Write a short poem on Rose",
            "What are computers",
        )
    elif selected_model == "Shakti-250M":
        prompts = (
            "Can you explain the pathophysiology of hypertension and its impact on the cardiovascular system?",
            "What are the potential side effects of beta-blockers in the treatment of arrhythmias?",
            "What foods are good for boosting the immune system?",
            "What is the difference between a stock and a bond?",
            "How can I start saving for retirement?",
            "What are some low-risk investment options?",
            "What is a power of attorney and when is it used?",
            "What are the key differences between a will and a trust?",
            "How do I legally protect my business name?",
        )
    else:
        # Default set, also used for "Shakti-2.5B".
        prompts = (
            "Tell me a story",
            "write a short poem which is hard to sing",
            'मुझे भारतीय इतिहास के बारे में बताएं',
        )
    return [[prompt] for prompt in prompts]
112
+
113
def on_model_select(selected_model):
    """Switch the active model and return the example prompts that go with it.

    Calls ``load_model(selected_model)`` first, then returns the result of
    ``update_examples(selected_model)``.

    NOTE(review): the dropdown's change handler sends this return value to
    ``outputs=[chat_interface]``; confirm that Gradio accepts a list of
    examples for that output.
    """
    load_model(selected_model)
    examples_for_model = update_examples(selected_model)
    return examples_for_model
116
+
117
 
118
  chat_interface = gr.ChatInterface(
119
  fn=generate,
 
132
  step=0.1,
133
  value=0.6,
134
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  ],
136
  stop_btn=None,
137
+ examples=update_examples("Shakti-2.5B"), # Set initial examples for 2.5B model
 
 
138
  cache_examples=False,
139
  )
140
 
141
  with gr.Blocks(css="style.css", fill_height=True) as demo:
142
  gr.Markdown(DESCRIPTION)
143
  gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
144
+
145
+ # Dropdown for model selection
146
+ model_dropdown = gr.Dropdown(
147
+ label="Select Model",
148
+ choices=["Shakti-100M", "Shakti-250M", "Shakti-2.5B"],
149
+ value="Shakti-2.5B",
150
+ interactive=True,
151
+ )
152
+
153
+ # Function to handle model change and update examples dynamically
154
+ model_dropdown.change(on_model_select, inputs=model_dropdown, outputs=[chat_interface])
155
+
156
  chat_interface.render()
157
 
158
  if __name__ == "__main__":
159
+ demo.queue(max_size=20).launch()