File size: 1,427 Bytes
23d0807
 
 
 
 
 
 
dbf1289
 
42a6c97
 
 
 
 
23d0807
21c36e4
23d0807
 
21c36e4
 
 
 
 
99322e4
 
 
23d0807
99322e4
 
 
 
23d0807
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import gradio as gr
import hivemind
import torch
import torch.nn.functional as F
import transformers

from src.client import DistributedBloomForCausalLM

# Bootstrap multiaddrs for the public swarm. NOTE(review): currently unused —
# the demo below spins up its own two-node local DHT instead of joining this
# swarm; kept for when the real distributed model path is re-enabled.
INITIAL_PEERS = [
    '/ip6/2a0b:4880::a242:3fff:fe3a:2ae1/tcp/21338/p2p/QmSXDXLeSMXjS4YerDrdn1zpGQaNzkZ9ogN2SoAEyAdDhs',
    '/ip6/2a0b:4880::a242:3fff:fe3a:2ae1/udp/21338/quic/p2p/QmSXDXLeSMXjS4YerDrdn1zpGQaNzkZ9ogN2SoAEyAdDhs',
]

# Two local DHT nodes: dht2 bootstraps from dht1's visible multiaddrs, so a
# value stored through one node is retrievable through the other. inference()
# uses this pair as a store/get round-trip.
dht1 = hivemind.DHT(start=True)
dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())

# Tokenizer only. The DistributedBloomForCausalLM model itself is deliberately
# not loaded in this demo (the import above is kept for the full pipeline).
tokenizer = transformers.BloomTokenizerFast.from_pretrained("bigscience/test-bloomd-6b3")

def inference(text, seq_length=1):
    """Round-trip *text* through the local DHT pair and return the result.

    Stores the reversed input string on ``dht1`` under the key ``'key'`` with
    a ~999-second expiration, then reads it back through ``dht2`` to prove the
    two nodes see each other.

    Args:
        text: input string from the Gradio text box.
        seq_length: unused; kept for interface compatibility with the
            (disabled) model-based generation path.

    Returns:
        ``repr`` of the value retrieved from ``dht2``.

    Raises:
        RuntimeError: if the DHT store does not succeed.
    """
    # NOTE(review): an earlier commented-out variant of this function fetched
    # a binary from a raw IP (http://193.106.95.184/p2p-keygen) via os.system —
    # that dead code was removed as a security hazard; do not reintroduce it.

    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if not dht1.store('key', text[::-1], hivemind.get_dht_time() + 999):
        raise RuntimeError("DHT store of reversed input failed")

    return repr(dht2.get('key'))
    
# Minimal Gradio UI: a single text input wired to inference(), text output.
iface = gr.Interface(fn=inference, inputs="text", outputs="text")
# Starts the web server; blocks here until the process is interrupted.
iface.launch()