Nekochu committed on
Commit
bee5b00
1 Parent(s): 5c7c7d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -20
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  from threading import Thread
3
  from typing import Iterator
4
-
5
  import gradio as gr
6
  import spaces
7
  import torch
@@ -11,15 +10,9 @@ MAX_MAX_NEW_TOKENS = 2048
11
  DEFAULT_MAX_NEW_TOKENS = 1024
12
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
13
 
14
- MODELS = {
15
- "Nekochu/Luminia-13B-v3": "Default - Nekochu/Luminia-13B-v3",
16
- "Nekochu/Llama-2-13B-German-ORPO": "German ORPO - Nekochu/Llama-2-13B-German-ORPO",
17
- }
18
-
19
  DESCRIPTION = """\
20
- # Text Generation with Selectable Models
21
-
22
- This Space demonstrates text generation using different models. Choose a model from the dropdown and experience its creative capabilities!
23
  """
24
 
25
  LICENSE = """
@@ -28,25 +21,29 @@ LICENSE = """
28
  """
29
 
30
  if not torch.cuda.is_available():
31
- DESCRIPTION += "\n<p>Running on CPU This demo does not work on CPU.</p>"
 
 
32
 
33
  @spaces.GPU(duration=120)
34
  def generate(
 
35
  message: str,
36
  chat_history: list[tuple[str, str]],
37
  system_prompt: str,
38
- model_id: str = None, # Add default value for model_id
39
  max_new_tokens: int = 1024,
40
  temperature: float = 0.6,
41
  top_p: float = 0.9,
42
  top_k: int = 50,
43
  repetition_penalty: float = 1.2,
44
  ) -> Iterator[str]:
45
- if not model_id:
46
- raise ValueError("Please select a model from the dropdown.")
47
- model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
48
- tokenizer = AutoTokenizer.from_pretrained(model_id)
49
- tokenizer.use_default_system_prompt = False
 
 
50
 
51
  conversation = []
52
  if system_prompt:
@@ -82,19 +79,17 @@ def generate(
82
  yield "".join(outputs)
83
 
84
 
85
- model_dropdown = gr.Dropdown(label="Select Model", choices=list(MODELS.values()))
86
-
87
  chat_interface = gr.ChatInterface(
88
  fn=generate,
89
  additional_inputs=[
90
- model_dropdown,
91
  gr.Textbox(label="System prompt", lines=6),
92
  gr.Slider(
93
  label="Max new tokens",
94
  minimum=1,
95
  maximum=MAX_MAX_NEW_TOKENS,
96
  step=1,
97
- value=DEFAULT_MAX
98
  ),
99
  gr.Slider(
100
  label="Temperature",
 
1
  import os
2
  from threading import Thread
3
  from typing import Iterator
 
4
  import gradio as gr
5
  import spaces
6
  import torch
 
10
  DEFAULT_MAX_NEW_TOKENS = 1024
11
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
12
 
 
 
 
 
 
13
  DESCRIPTION = """\
14
+ # Nekochu/Luminia-13B-v3
15
+ This Space demonstrates model Nekochu/Luminia-13B-v3 by Nekochu, a Llama 2 model with 13B parameters fine-tuned for SD gen prompt
 
16
  """
17
 
18
  LICENSE = """
 
21
  """
22
 
23
  if not torch.cuda.is_available():
24
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
25
+
26
+ models_cache = {}
27
 
28
  @spaces.GPU(duration=120)
29
  def generate(
30
+ model_id: str,
31
  message: str,
32
  chat_history: list[tuple[str, str]],
33
  system_prompt: str,
 
34
  max_new_tokens: int = 1024,
35
  temperature: float = 0.6,
36
  top_p: float = 0.9,
37
  top_k: int = 50,
38
  repetition_penalty: float = 1.2,
39
  ) -> Iterator[str]:
40
+ if model_id not in models_cache:
41
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
42
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
43
+ tokenizer.use_default_system_prompt = False
44
+ models_cache[model_id] = (model, tokenizer)
45
+ else:
46
+ model, tokenizer = models_cache[model_id]
47
 
48
  conversation = []
49
  if system_prompt:
 
79
  yield "".join(outputs)
80
 
81
 
 
 
82
  chat_interface = gr.ChatInterface(
83
  fn=generate,
84
  additional_inputs=[
85
+ gr.Textbox(label="Model ID", placeholder="Nekochu/Luminia-13B-v3"),
86
  gr.Textbox(label="System prompt", lines=6),
87
  gr.Slider(
88
  label="Max new tokens",
89
  minimum=1,
90
  maximum=MAX_MAX_NEW_TOKENS,
91
  step=1,
92
+ value=DEFAULT_MAX_NEW_TOKENS,
93
  ),
94
  gr.Slider(
95
  label="Temperature",