gosha6037 committed
Commit aae4195
1 Parent(s): fd09410

Test old version

Files changed (1)
  1. app.py +88 -27
app.py CHANGED
@@ -1,35 +1,96 @@
 import sys
- sys.path.insert(0, './petals/')
-
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
 import torch
 import transformers
 import gradio as gr

+ sys.path.insert(0, './petals/')
+
 from src.client.remote_model import DistributedBloomForCausalLM

- MODEL_NAME = "bigscience/bloom-petals"
-
- tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
- model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)
-
- def inference(text, seq_length=1):
-     input_ids = tokenizer([text], return_tensors="pt").input_ids
-     output = model.generate(input_ids, max_new_tokens=seq_length)
-     return tokenizer.batch_decode(output)[0]
-
- iface = gr.Interface(
-     fn=inference,
-     inputs=[
-         gr.Textbox(lines=10, label="Input text"),
-         gr.inputs.Slider(
-             minimum=0,
-             maximum=1000,
-             step=1,
-             default=42,
-             label="Sequence length for generation"
-         )
-     ],
-     outputs="text"
- )
- iface.launch()
+ MODEL_NAME = "bigscience/test-bloomd-6b3"  # select the model you like
+ INITIAL_PEERS = ["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"]
+
+ tokenizer_bloomd_6b3 = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+ model_bloomd_6b3 = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME,
+                                                                initial_peers=INITIAL_PEERS,
+                                                                low_cpu_mem_usage=True, torch_dtype=torch.float32)
+
+
+ tokenizer_DialoGPT_small = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
+ model_DialoGPT_small = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
+
+ tokenizer_DialoGPT_medium = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
+ model_DialoGPT_medium = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
+
+ tokenizer_DialoGPT_large = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
+ model_DialoGPT_large = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
+
+
30
+ def predict(
+         input_text,
+         history=None,
+         person_description=None,
+         number_of_new_tokens=1000,
+         model_name=None,
+         del_hist=None
+ ):
+     if history is None or del_hist == 'delete history':
+         history = []
+     if model_name == 'DialoGPT-small':
+         model = model_DialoGPT_small
+         tokenizer = tokenizer_DialoGPT_small
+     elif model_name == 'DialoGPT-medium':
+         model = model_DialoGPT_medium
+         tokenizer = tokenizer_DialoGPT_medium
+     elif model_name == 'DialoGPT-large':
+         model = model_DialoGPT_large
+         tokenizer = tokenizer_DialoGPT_large
+     elif model_name == 'test-bloomd-6b3':
+         model = model_bloomd_6b3
+         tokenizer = tokenizer_bloomd_6b3
+     else:
+         model = model_DialoGPT_medium
+         tokenizer = tokenizer_DialoGPT_medium
+
+     person_description_ids = tokenizer.encode(person_description + tokenizer.eos_token, return_tensors='pt')
+     new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')
+
+     bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
+     input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)
+     max_token_count = number_of_new_tokens + len(input_with_desc_ids[0])
+     history = model.generate(input_with_desc_ids, max_length=max_token_count,
+                              pad_token_id=tokenizer.eos_token_id).tolist()
+     history[0] = history[0][len(person_description_ids[0]):]
+
+     response = tokenizer.decode(history[0]).split("<|endoftext|>")
+     response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
+     return response, history
+
+
+ gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.Textbox(label='Input message', lines=1, placeholder="Enter your message..."),
+         "state",
+         gr.Textbox(label='Person Description', lines=2, placeholder="Enter a description of the person..."),
+         gr.Slider(label='Number of new tokens', minimum=2, maximum=100, value=10),
+         gr.Radio(
+             label='Model name',
+             choices=[
+                 'DialoGPT-small',
+                 'DialoGPT-medium',
+                 'DialoGPT-large',
+                 'test-bloomd-6b3'
+             ]
+         ),
+         gr.Radio(
+             label='Delete history',
+             value="Don't delete history",
+             choices=[
+                 'delete history',
+                 "Don't delete history"
+             ]
+         ),
+     ],
+     outputs=[gr.Chatbot(label='History of the dialogue'), "state"],
+ ).launch()
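
Two details of predict are worth spelling out. First, it uses the standard DialoGPT turn format: every user and bot turn ends with the tokenizer's EOS token, the whole token history is fed back into generate(), and the decoded text is split on <|endoftext|> to recover (user, bot) pairs for the Chatbot widget. Second, the person description is prepended as extra tokens before generation and sliced off afterwards, so it conditions the reply without appearing in the visible dialogue. A minimal sketch of the turn loop (the chat_turn helper is illustrative, not part of the app), assuming microsoft/DialoGPT-medium and adding a guard for the empty first-turn history, since torch.cat on an empty LongTensor fails on recent torch versions:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

    def chat_turn(user_text, history_ids=None):
        # Each turn ends with EOS so the model sees the turn boundary.
        new_ids = tokenizer.encode(user_text + tokenizer.eos_token, return_tensors="pt")
        # Guard the first turn: there is no history to concatenate yet.
        bot_input_ids = new_ids if history_ids is None else torch.cat([history_ids, new_ids], dim=-1)
        history_ids = model.generate(bot_input_ids,
                                     max_length=bot_input_ids.shape[-1] + 50,
                                     pad_token_id=tokenizer.eos_token_id)
        # The reply is everything generated after the prompt.
        reply = tokenizer.decode(history_ids[0, bot_input_ids.shape[-1]:], skip_special_tokens=True)
        return reply, history_ids

    reply, history = chat_turn("Hello, how are you?")
    print(reply)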
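On the Petals side, the client API mirrors an ordinary transformers model: DistributedBloomForCausalLM.from_pretrained() joins the swarm named by initial_peers, and generate() then runs the transformer blocks on remote servers. A sketch of the same loading path the commit uses, assuming the ./petals checkout is present and the listed peer is still reachable:

    import sys
    sys.path.insert(0, './petals/')

    import torch
    import transformers
    from src.client.remote_model import DistributedBloomForCausalLM

    INITIAL_PEERS = ["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"]

    tokenizer = transformers.BloomTokenizerFast.from_pretrained("bigscience/test-bloomd-6b3")
    model = DistributedBloomForCausalLM.from_pretrained(
        "bigscience/test-bloomd-6b3",
        initial_peers=INITIAL_PEERS,      # bootstrap address of the swarm
        low_cpu_mem_usage=True,           # keep peak RAM low while loading local weights
        torch_dtype=torch.float32,
    )

    input_ids = tokenizer(["A cat sat on"], return_tensors="pt").input_ids
    output = model.generate(input_ids, max_new_tokens=8)  # forward passes go over the network
    print(tokenizer.batch_decode(output)[0])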
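Lastly, the "state" placeholders in inputs and outputs are Gradio's per-session state shortcut: whatever the function returns in the state slot is handed back as the matching argument on the next call, which is how history persists across turns without globals, and why the 'delete history' radio simply resets it to an empty list. A stripped-down sketch of the mechanism (count_turns is a toy function):

    import gradio as gr

    def count_turns(message, history):
        history = (history or []) + [message]  # state arrives as None on a session's first call
        return f"{len(history)} message(s) so far", history

    gr.Interface(
        fn=count_turns,
        inputs=[gr.Textbox(label="Message"), "state"],  # "state" receives the previously returned state
        outputs=["text", "state"],                      # second return value becomes next call's state
    ).launch()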