File size: 1,171 Bytes
23d0807
 
 
 
 
 
 
1e25943
dbf1289
4c16a19
42a6c97
 
 
 
23d0807
4c16a19
23d0807
 
4c16a19
 
 
 
 
 
 
7a31076
23d0807
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import torch.nn.functional as F
import transformers
import gradio as gr

from src.client import DistributedBloomForCausalLM

# Bootstrap peer multiaddr(s) passed to the distributed model as ``initial_peers``.
INITIAL_PEERS = [
    '/ip4/193.106.95.184/tcp/443/p2p/QmSXDXLeSMXjS4YerDrdn1zpGQaNzkZ9ogN2SoAEyAdDhs',
]

# Smoke test: check that DHT instances work on localhost. A first node is
# started on its own, then a second one is started with the first node's
# visible multiaddrs as its initial peers. Neither handle is used again in
# the code visible here.
import hivemind

dht1 = hivemind.DHT(start=True)
dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())


# One checkpoint name shared by the tokenizer and the distributed model.
MODEL_NAME = "bigscience/test-bloomd-6b3"

tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME,
    initial_peers=INITIAL_PEERS,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)

def inference(text, seq_length=1):
    """Push a prompt through the remote transformer blocks and return the result's repr.

    The prompt is tokenized with the module-level ``tokenizer``, embedded and
    layer-normalized with the local layers of ``model.transformer``, and then
    sent through one ``step`` of an inference session opened on
    ``model.transformer.h``.

    Args:
        text: Prompt string to tokenize.
        seq_length: Passed to ``range()`` to drive the step loop.
            NOTE(review): the loop body returns on its first pass, so any
            value >= 1 performs exactly one step; confirm whether multi-step
            generation was intended.

    Returns:
        ``repr()`` of the value returned by the session's ``step`` call
        (presumably the output hidden states; TODO confirm), or ``None``
        when ``seq_length`` is less than 1 and the loop never executes.
    """
    token_ids = tokenizer(text, return_tensors='pt')['input_ids']
    with torch.inference_mode():
        backbone = model.transformer
        with backbone.h.inference_session() as session:
            for _ in range(seq_length):
                embedded = backbone.word_embeddings(token_ids)
                normalized = backbone.word_embeddings_layernorm(embedded)
                # Returns during the first iteration; later iterations never run.
                return repr(session.step(normalized))
    
# Wrap `inference` in a text-in / text-out Gradio interface and launch it
# as soon as this module is executed.
iface = gr.Interface(
    fn=inference,
    inputs="text",
    outputs="text",
)
iface.launch()