kz209 committed
Commit 1921336 · 1 Parent(s): b97cda3
pages/arena.py CHANGED
@@ -1,6 +1,49 @@
-from utils.multiple_stream import create_interface
+#from utils.multiple_stream import create_interface
+import random
+import gradio as gr
+import json
+import logging
+import gc
+import torch
+
+from utils.data import dataset
+from utils.multiple_stream import stream_data
+from summarization_playground import get_model_batch_generation
 
 def create_arena():
-    demo = create_interface()
+    with gr.Blocks() as demo:
+        with gr.Group():
+            gr.Markdown("## This is a playground to test prompts for clinical dialogue summarizations")
+
+        with open("prompt/prompt.json", "r") as file:
+            json_data = file.read()
+            prompts = json.loads(json_data)
+
+        datapoint = random.choice(dataset)
+        datapoint = datapoint['section_text'] + '\n\nDialogue:\n' + datapoint['dialogue']
+        submit_button = gr.Button("✨ Submit ✨")
+
+        with gr.Row():
+            columns = [gr.Textbox(label=f"Column {i+1}", lines=10) for i in range(3)]
+
+        content_list = [prompt + '\n{' + datapoint + '}\n\nsummary:' for prompt in prompts]
+        model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")
+
+        def start_streaming():
+            for data in stream_data(content_list, model):
+                updates = [gr.update(value=data[i]) for i in range(len(columns))]
+                yield tuple(updates)
+
+        submit_button.click(
+            fn=start_streaming,
+            inputs=[],
+            outputs=columns,
+            show_progress=False
+        )
 
     return demo
+
+if __name__ == "__main__":
+    demo = create_arena()
+    demo.queue()
+    demo.launch()
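For context, the new arena page drives all three output boxes from a single generator callback: each `yield` hands Gradio one tuple of partial strings, one entry per textbox. Below is a minimal, self-contained sketch of that wiring; `fake_stream` and the hard-coded `TEXTS` are stand-ins of mine for the repo's `stream_data` plus a real model, so the flow can be tried without loading any weights.

```python
# Sketch: one generator updates three Gradio textboxes with partial outputs.
# `fake_stream` is a placeholder for stream_data(content_list, model).
import time
import gradio as gr

TEXTS = [
    "first candidate summary of the dialogue",
    "second candidate summary of the dialogue",
    "third candidate summary of the dialogue",
]

def fake_stream(texts):
    """Yield growing (col1, col2, col3) snapshots, one word at a time."""
    outputs = ["" for _ in texts]
    words = [t.split() for t in texts]
    for step in range(max(len(w) for w in words)):
        for i, w in enumerate(words):
            if step < len(w):
                outputs[i] += w[step] + " "
        time.sleep(0.05)
        yield tuple(outputs)

with gr.Blocks() as demo:
    with gr.Row():
        columns = [gr.Textbox(label=f"Column {i+1}", lines=10) for i in range(3)]
    submit_button = gr.Button("Submit")

    def start_streaming():
        yield from fake_stream(TEXTS)

    submit_button.click(fn=start_streaming, inputs=[], outputs=columns)

if __name__ == "__main__":
    demo.queue()  # queuing is required for generator (streaming) callbacks
    demo.launch()
```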
pages/summarization_playground.py CHANGED
@@ -33,8 +33,24 @@ Back in Boston, Kidd is going to rely on Lively even more. He'll play close to 3
     random_label: ""
 }
 
+
+def get_model_batch_generation(model_name):
+    global __model_on_gpu__
+
+    if __model_on_gpu__ != model_name:
+        if __model_on_gpu__:
+            logging.info(f"delete model {__model_on_gpu__}")
+            del model[__model_on_gpu__]
+            gc.collect()
+            torch.cuda.empty_cache()
+
+        model[model_name] = Model(model_name)
+        __model_on_gpu__ = model_name
+
+    return model[model_name]
+
+
 def generate_answer(sources, model_name, prompt):
-    content = prompt + '\n{' + sources + '}\n\nsummary:'
     global __model_on_gpu__
 
     if __model_on_gpu__ != model_name:
@@ -47,6 +63,8 @@ def generate_answer(sources, model_name, prompt):
         model[model_name] = Model(model_name)
         __model_on_gpu__ = model_name
 
+    content = prompt + '\n{' + sources + '}\n\nsummary:'
+
     answer = model[model_name].gen(content)
 
     return answer
@@ -68,7 +86,7 @@ def update_input(example):
 def create_summarization_interface():
     with gr.Blocks() as demo:
         gr.Markdown("## This is a playground to test prompts for clinical dialogue summarizations")
-
+
         with gr.Row():
             example_dropdown = gr.Dropdown(choices=list(examples.keys()), label="Choose an example", value=random_label)
             model_dropdown = gr.Dropdown(choices=Model.__model_list__, label="Choose a model", value=Model.__model_list__[0])
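The new `get_model_batch_generation` helper follows the same eviction pattern `generate_answer` already used: at most one model stays resident, and the current one is deleted and garbage-collected before the next is constructed. Here is a compact, CPU-only sketch of that pattern; `DummyModel` and `load_resident` are illustrative names of mine, not repo code.

```python
# Sketch of the "one model resident at a time" cache used by the playground.
import gc
import logging

class DummyModel:
    """Stand-in for utils.model.Model; pretend __init__ allocates GPU memory."""
    def __init__(self, name: str):
        self.name = name

_resident = {}
_resident_name = None

def load_resident(model_name: str) -> DummyModel:
    """Return the requested model, evicting whatever is currently loaded first."""
    global _resident_name
    if _resident_name != model_name:
        if _resident_name is not None:
            logging.info("deleting model %s", _resident_name)
            del _resident[_resident_name]
            gc.collect()  # a real version would also call torch.cuda.empty_cache()
        _resident[model_name] = DummyModel(model_name)
        _resident_name = model_name
    return _resident[model_name]

if __name__ == "__main__":
    load_resident("model-a")
    current = load_resident("model-b")  # "model-a" is dropped before "model-b" is built
    print(current.name)
```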
prompt/prompt.ipynb ADDED
@@ -0,0 +1,59 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "\n",
+    "prompts = [\n",
+    "    \"\"\"Please summarize the following conversation by highlighting the key points and main topics discussed. Include any important conclusions or decisions made during the conversation.\n",
+    "Conversation:\"\"\",\n",
+    "    \"\"\"Generate a concise summary of the conversation below. Focus on the main arguments, the flow of the discussion, and any significant outcomes or agreements reached. Make sure to capture the essence of the dialogue without including extraneous details.\n",
+    "Conversation:\"\"\",\n",
+    "    \"\"\"Provide a brief overview of the conversation provided. Summarize the main ideas exchanged, the context of the discussion, and any resolutions or actions decided. Ensure the summary is clear and easy to understand for someone who wasn't part of the conversation\"\"\"\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open(\"prompt.json\", \"w\") as f:\n",
+    "    json.dump(prompts, f)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
prompt/prompt.json ADDED
@@ -0,0 +1 @@
+["Please summarize the following conversation by highlighting the key points and main topics discussed. Include any important conclusions or decisions made during the conversation.\nConversation:", "Generate a concise summary of the conversation below. Focus on the main arguments, the flow of the discussion, and any significant outcomes or agreements reached. Make sure to capture the essence of the dialogue without including extraneous details.\nConversation:", "Provide a brief overview of the conversation provided. Summarize the main ideas exchanged, the context of the discussion, and any resolutions or actions decided. Ensure the summary is clear and easy to understand for someone who wasn't part of the conversation"]
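`prompt/prompt.json` is a plain JSON list of three prompt strings; the arena page reads it back and prepends each prompt to the same datapoint before generation. A small round-trip sketch of that usage follows; the `datapoint` text here is made up for illustration.

```python
# Read the prompt list back and build one generation input per prompt.
import json

with open("prompt/prompt.json", "r") as file:
    prompts = json.load(file)  # list of three prompt strings

datapoint = "Doctor: How are you feeling?\nPatient: Much better, thank you."  # made-up example
content_list = [prompt + '\n{' + datapoint + '}\n\nsummary:' for prompt in prompts]

for content in content_list:
    print(content.splitlines()[0])  # show which prompt variant each input starts with
```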
utils/model.py CHANGED
@@ -1,4 +1,4 @@
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextStreamer
 import transformers
 import torch
 
@@ -12,26 +12,32 @@ login(token = os.getenv('HF_TOKEN'))
 class Model(torch.nn.Module):
     number_of_models = 0
     __model_list__ = [
+        "Qwen/Qwen2-1.5B-Instruct",
         "lmsys/vicuna-7b-v1.5",
         "google-t5/t5-large",
         "mistralai/Mistral-7B-Instruct-v0.1",
         "meta-llama/Meta-Llama-3.1-8B-Instruct"
     ]
 
-    def __init__(self, model_name="lmsys/vicuna-7b-v1.5") -> None:
+    def __init__(self, model_name="Qwen/Qwen2-1.5B-Instruct") -> None:
         super(Model, self).__init__()
 
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.name = model_name
 
         logging.info(f'start loading model {self.name}')
-        self.pipeline = transformers.pipeline(
-            "summarization" if model_name=="google-t5/t5-large" else "text-generation",
-            model=model_name,
-            tokenizer=self.tokenizer,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-        )
+
+        if model_name == "google-t5/t5-large":
+            # For T5 or any other Seq2Seq model
+            self.model = AutoModelForSeq2SeqLM.from_pretrained(
+                model_name, torch_dtype=torch.bfloat16, device_map="auto"
+            )
+        else:
+            # For GPT-like models or other causal language models
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_name, torch_dtype=torch.bfloat16, device_map="auto"
+            )
+
         logging.info(f'Loaded model {self.name}')
 
         self.update()
@@ -49,25 +55,32 @@ class Model(torch.nn.Module):
     def return_model(self):
         return self.pipeline
 
-    def gen(self, content, temp=0.1, max_length=500):
-        if self.name == "google-t5/t5-large":
-            sequences = self.pipeline(
-                content,
-                max_new_tokens=max_length,
-                do_sample=True,
-                temperature=temp,
-                num_return_sequences=1,
-                eos_token_id=self.tokenizer.eos_token_id,
-            )
-            return sequences[-1]['summary_text']
+    def gen(self, content_list, temp=0.1, max_length=500, streaming=False):
+        # Convert list of texts to input IDs
+        input_ids = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True).input_ids.to(self.model.device)
+
+        if streaming:
+            # Prepare streamers for each input
+            streamers = [TextStreamer(self.tokenizer, skip_prompt=True) for _ in content_list]
+
+            # Stream the output token by token for each input text
+            for i, streamer in enumerate(streamers):
+                for output in self.model.generate(
+                    input_ids[i].unsqueeze(0),  # Process each input separately
+                    max_new_tokens=max_length,
+                    do_sample=True,
+                    temperature=temp,
+                    eos_token_id=self.tokenizer.eos_token_id,
+                    return_dict_in_generate=True,
+                    output_scores=True,
+                    streamer=streamer):
+                    pass  # TextStreamer automatically handles the streaming, no need to manually handle the output
         else:
-            sequences = self.pipeline(
-                content,
+            outputs = self.model.generate(
+                input_ids,
                 max_new_tokens=max_length,
                 do_sample=True,
                 temperature=temp,
-                num_return_sequences=1,
-                eos_token_id=self.tokenizer.eos_token_id,
-                return_full_text=False
+                eos_token_id=self.tokenizer.eos_token_id
             )
-            return sequences[-1]['generated_text']
+            return [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
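A note when reading `gen(..., streaming=True)`: `TextStreamer` only prints tokens to stdout, so a caller cannot iterate over the generated text. If the chunks need to be consumed programmatically (for example, fed back into a UI), `transformers.TextIteratorStreamer` is the usual alternative. The sketch below is not part of this commit; it assumes the same Qwen default model and illustrative names (`stream_one`) of mine.

```python
# Sketch: stream generated text as an iterator with TextIteratorStreamer.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "Qwen/Qwen2-1.5B-Instruct"  # same default the commit uses
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
)

def stream_one(prompt, max_new_tokens=200):
    """Yield decoded text chunks for a single prompt as they are generated."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so run it in a background thread and read from the streamer.
    thread = Thread(
        target=model.generate,
        kwargs=dict(**inputs, max_new_tokens=max_new_tokens, do_sample=True,
                    temperature=0.1, streamer=streamer),
    )
    thread.start()
    for chunk in streamer:
        yield chunk
    thread.join()

if __name__ == "__main__":
    for piece in stream_one("Summarize: the patient reports mild headaches.\n\nsummary:"):
        print(piece, end="", flush=True)
```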
utils/multiple_stream.py CHANGED
@@ -3,6 +3,9 @@ import random
 from time import sleep
 import gradio as gr
 
+from utils.model import Model
+
+
 TEST = """ Test of Time. A Benchmark for Evaluating LLMs on Temporal Reasoning. Large language models (LLMs) have
 showcased remarkable reasoning capabilities, yet they remain susceptible to errors, particularly in temporal
 reasoning tasks involving complex temporal logic. """
@@ -16,16 +19,16 @@ def generate_data_test():
     for word in temp.split(" "):
         yield word + " "
 
-def stream_data(progress=gr.Progress()):
-    """Stream data to all columns"""
-    outputs = ["", "", ""]
-    generators = [generate_data_test() for _ in range(3)]
-
+def stream_data(content_list, model):
+    """Stream data to three columns"""
+    outputs = ["" for _ in content_list]
+
+    # Use the gen method to handle batch generation
     while True:
         updated = False
-        for i, gen in enumerate(generators):
+        for i, content in enumerate(content_list):
             try:
-                word = next(gen)
+                word = next(model.gen([content], streaming=True))  # Wrap content in a list to match expected input type
                 outputs[i] += word
                 updated = True
             except StopIteration:
@@ -35,24 +38,31 @@ def stream_data(progress=gr.Progress()):
            break
 
        yield tuple(outputs)
-        sleep(0.01)
+
 
 def create_interface():
-    with gr.Group():
-        with gr.Row():
-            col1 = gr.Textbox(label="Column 1", lines=10)
-            col2 = gr.Textbox(label="Column 2", lines=10)
-            col3 = gr.Textbox(label="Column 3", lines=10)
-
-        start_btn = gr.Button("Start Streaming")
-
-        start_btn.click(
-            fn=stream_data,
-            outputs=[col1, col2, col3],
-            show_progress=False
-        )
+    with gr.Blocks() as demo:
+        with gr.Group():
+            with gr.Row():
+                columns = [gr.Textbox(label=f"Column {i+1}", lines=10) for i in range(3)]
+
+            start_btn = gr.Button("Start Streaming")
+
+            def start_streaming():
+                content_list = [col.value for col in columns]  # Get input texts from text boxes
+                for data in stream_data(content_list):
+                    updates = [gr.update(value=data[i]) for i in range(len(columns))]
+                    yield tuple(updates)
+
+            start_btn.click(
+                fn=start_streaming,
+                inputs=[],
+                outputs=columns,
+                show_progress=False
+            )
+
+    return demo
 
-#return demo
 
 if __name__ == "__main__":
     demo = create_interface()
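The `stream_data` rewrite keeps the original round-robin shape: advance one source per column on each pass and yield a snapshot tuple of all columns until every source is exhausted. Below is a dependency-free sketch of that loop, with plain word generators standing in for model output; `words` and `round_robin` are illustrative names of mine, not repo code.

```python
# Sketch: interleave several generators and yield a snapshot tuple per pass.
def words(text):
    for w in text.split():
        yield w + " "

def round_robin(texts):
    gens = [words(t) for t in texts]
    outputs = ["" for _ in texts]
    done = [False] * len(texts)
    while not all(done):
        for i, gen in enumerate(gens):
            if done[i]:
                continue
            try:
                outputs[i] += next(gen)
            except StopIteration:
                done[i] = True
        yield tuple(outputs)

if __name__ == "__main__":
    for snapshot in round_robin(["alpha bravo charlie", "one two", "w x y z"]):
        print(snapshot)
```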