Artples committed on
Commit b241b47
1 Parent(s): d05908c

Update app.py

Files changed (1)
  1. app.py +98 -17
app.py CHANGED
@@ -5,7 +5,7 @@ from typing import Iterator
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
@@ -19,15 +19,22 @@ This Space demonstrates [L-MChat](https://huggingface.co/collections/Artples/l-m
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU! This demo does not work on CPU.</p>"
 
-model_options = {
+# Dictionary to manage model details
+model_details = {
     "Fast-Model": "Artples/L-MChat-Small",
     "Quality-Model": "Artples/L-MChat-7b"
 }
 
+# Initialize models and tokenizers based on availability
+models = {name: AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") for name, model_id in model_details.items()}
+tokenizers = {name: AutoTokenizer.from_pretrained(model_id) for name, model_id in model_details.items()}
+for tokenizer in tokenizers.values():
+    tokenizer.use_default_system_prompt = False
+
 @spaces.GPU(enable_queue=True, duration=90)
 def generate(
-    message: str,
     model_choice: str,
+    message: str,
     chat_history: list[tuple[str, str]],
     system_prompt: str,
     max_new_tokens: int = 1024,
@@ -36,23 +43,97 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    # Your existing function implementation...
-    pass
+    model = models[model_choice]
+    tokenizer = tokenizers[model_choice]
+
+    conversation = []
+    if system_prompt:
+        conversation.append({"role": "system", "content": system_prompt})
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
 
-chat_interface = gr.Interface(
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
+
+chat_interface = gr.ChatInterface(
+    theme='ehristoforu/RE_Theme',
     fn=generate,
-    inputs=[
-        gr.Textbox(lines=2, placeholder="Type your message here..."),
-        gr.Dropdown(label="Choose Model", choices=list(model_options.keys())),
-        chat_history,  # Updated to include state without label
-        gr.Textbox(label="System Prompt", lines=6, placeholder="Enter system prompt if any..."),
-        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
-        # More inputs as previously defined
+    additional_inputs=[
+        gr.Textbox(label="System prompt", lines=6),
+        gr.Dropdown(label="Model Choice", choices=list(model_details.keys()), value="Quality-Model"),
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.2,
+        ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["Hello there! How are you doing?"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+        ["Explain the plot of Cinderella in a sentence."],
+        ["How many hours does it take a man to eat a Helicopter?"],
+        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
     ],
-    outputs=[gr.Textbox(label="Response")],
-    theme="default",
-    description=DESCRIPTION
 )
 
+with gr.Blocks(css="style.css") as demo:
+    gr.Markdown(DESCRIPTION)
+    chat_interface.render()
+
 if __name__ == "__main__":
-    chat_interface.launch()
+    demo.queue(max_size=20).launch()
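
The added generate() body references two names that never appear in this diff, Thread and MAX_INPUT_TOKEN_LENGTH, so they presumably live in unchanged parts of app.py. A minimal sketch of what that module-level setup might look like; the env-var name and the 4096 default are assumptions, not taken from this commit:

# Sketch only: module-level setup assumed by the streaming generate() above.
# Thread and MAX_INPUT_TOKEN_LENGTH are referenced by the added code; the
# exact definitions below are illustrative assumptions.
import os
from threading import Thread  # used for Thread(target=model.generate, ...)

# Upper bound on prompt tokens before the conversation is trimmed.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))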
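
Since generate() is a generator that yields progressively longer strings, it can also be exercised outside the Gradio UI. A rough usage sketch under the signature shown in the diff; the argument values are illustrative only, and it assumes the @spaces.GPU decorator is a no-op in your environment:

# Illustrative only: stream a reply from the "Fast-Model" entry defined above.
for partial in generate(
    model_choice="Fast-Model",
    message="Hello there! How are you doing?",
    chat_history=[],
    system_prompt="You are a helpful assistant.",
    max_new_tokens=64,
):
    print(partial)  # each yield is the full response text so far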