File size: 3,570 Bytes
1acf699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f26c118
 
a9c34c4
f26c118
 
 
 
 
a9c34c4
1acf699
f26c118
1acf699
 
adc0e7a
 
 
 
 
8f69aac
adc0e7a
 
 
0a95580
 
 
 
 
 
8f69aac
0a95580
a0c5805
adc0e7a
 
 
47a9ca5
 
 
 
 
1acf699
 
 
 
 
0a95580
 
 
1acf699
 
 
 
865ea06
0a95580
 
1acf699
adc0e7a
 
0a95580
adc0e7a
865ea06
adc0e7a
47a9ca5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import streamlit as st
import wget, os, io, ast
import matplotlib.pyplot as plt
from PIL import Image
from genQC.pipeline.diffusion_pipeline import DiffusionPipeline
from genQC.inference.infer_srv import generate_srv_tensors, convert_tensors_to_srvs
from genQC.util import infer_torch_device

#--------------------------------
# download model into storage

# Local cache directory for the pretrained model files.
save_destination = "saves/"

# Config and weights of the pretrained SRV model, pinned to a fixed commit hash
# of the genQC repository so the demo always loads the same checkpoint.
url_config = (
    "https://github.com/FlorianFuerrutter/genQC/blob/044f7da6ebe907bd796d3db293024db223cc1852"
    "/saves/qc_unet_config_SRV_3to8_qubit/config.yaml"
)
url_weights = (
    "https://github.com/FlorianFuerrutter/genQC/blob/044f7da6ebe907bd796d3db293024db223cc1852"
    "/saves/qc_unet_config_SRV_3to8_qubit/model.pt"
)

def download(url, dst_dir):
    """Download ``url`` into ``dst_dir`` unless a file of the same name is already there.

    The local file is named after the last path component of ``url``; an
    existing file with that name is reused as-is (no re-download). The request
    appends ``?raw=true`` to ``url`` -- presumably so GitHub serves the raw file
    instead of its HTML "blob" page; NOTE(review): this assumes ``url`` carries
    no query string of its own.

    Args:
        url: Remote file URL.
        dst_dir: Directory to store the file in. Missing directories (including
            parents) are created; an empty string means the current directory.

    Returns:
        Path of the local file.
    """
    # makedirs(exist_ok=True) handles nested paths and avoids the
    # check-then-create race of os.path.exists + os.mkdir; skip it for "" since
    # that denotes the current directory, which always exists.
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)
    filename = os.path.join(dst_dir, os.path.basename(url))
    if not os.path.exists(filename):
        filename = wget.download(url + "?raw=true", out=filename)
    return filename

# Fetch the model config and weights at import time (existing local copies
# in save_destination are reused). The "weigths_file" spelling is kept so any
# other reference to this module-level name keeps working.
config_file = download(url_config, dst_dir=save_destination)
weigths_file = download(url_weights, dst_dir=save_destination)

#--------------------------------
# setup

@st.cache_resource
def load_pipeline():
    """Build the genQC diffusion pipeline from the files under ``save_destination``.

    Decorated with ``st.cache_resource`` so Streamlit script reruns reuse the
    cached pipeline object instead of reloading the model each time.

    Returns:
        The loaded ``DiffusionPipeline`` with ``scheduler.set_timesteps(20)`` applied.
    """
    device = infer_torch_device()
    diffusion = DiffusionPipeline.from_config_file(save_destination, device)
    # Presumably the number of sampling (denoising) steps -- TODO confirm in genQC.
    diffusion.scheduler.set_timesteps(20)
    return diffusion

# Shared pipeline used by get_qcs; load_pipeline is cached via st.cache_resource.
pipeline = load_pipeline()

# NOTE(review): this flag is never set to anything but False in the visible
# code and appears to be unused.
is_gpu_busy = False
def get_qcs(srv, num_of_qubits, max_gates, g):
    """Sample circuits for a target SRV and plot them in a 3x2 grid.

    Calls ``generate_srv_tensors`` on the module-level ``pipeline`` with the
    prompt ``"Generate SRV: {srv}"`` and 6 samples, converts the tensors with
    ``convert_tensors_to_srvs``, and draws each circuit into its own panel,
    titled with whether that circuit's SRV equals ``srv``. Progress is shown
    in an ``st.status`` container.

    Args:
        srv: Target SRV, compared with ``==`` against each generated circuit's
            SRV (the caller passes a list parsed from user input).
        num_of_qubits: Passed as both ``system_size`` and ``num_of_qubits``.
        max_gates: Forwarded to ``generate_srv_tensors``.
        g: Guidance scale forwarded to ``generate_srv_tensors``.

    Returns:
        A matplotlib ``Figure`` holding the 3x2 grid. Panels for which no
        circuit was returned keep the "Circuit generated with errors" placeholder.
    """
    # Local import: matplotlib is already a dependency of this app; only the
    # object-oriented Figure class is needed here.
    from matplotlib.figure import Figure

    with st.status("Generation started", expanded=True) as status:
        st.write("Generating tensors...")
        out_tensor = generate_srv_tensors(
            pipeline,
            f"Generate SRV: {srv}",
            samples=6,
            system_size=num_of_qubits,
            num_of_qubits=num_of_qubits,
            max_gates=max_gates,
            g=g,
        )

        st.write("Converting to circuits...")
        qc_list, _, srv_list = convert_tensors_to_srvs(out_tensor, pipeline.gate_pool)

        st.write("Plotting...")
        # Build the figure with the object-oriented API instead of
        # plt.subplots: pyplot keeps every figure it creates in global state
        # until plt.close(), so each generation leaked a figure in the
        # long-running Streamlit server, and that shared state is not safe
        # across concurrent sessions. A bare Figure is freed normally and can
        # still be rendered by st.pyplot(fig).
        fig = Figure(figsize=(7, 10), constrained_layout=True, dpi=120)
        axs = fig.subplots(3, 2)

        # Pre-fill every panel with a centred placeholder; panels that receive
        # a circuit below are cleared and redrawn, so the placeholder remains
        # only where zip() runs out of circuits.
        for ax in axs.flat:
            ax.axis("off")
            ax.text(
                0.5, 0.5, "Circuit generated with errors",
                ha="center", va="center", transform=ax.transAxes,
            )

        for qc, circuit_srv, ax in zip(qc_list, srv_list, axs.flat):
            ax.clear()
            qc.draw("mpl", plot_barriers=False, ax=ax, style="clifford")
            verdict = "Correct" if circuit_srv == srv else "NOT correct"
            ax.set_title(f"{verdict}, is SRV = {circuit_srv}")

        status.update(label="Generation complete!", state="complete", expanded=False)

    return fig

#--------------------------------
# run

st.title("genQC · Generative Quantum Circuits")
st.write("""
Generating quantum circuits with diffusion models. Official demo of [[paper-arxiv]](https://arxiv.org/abs/2311.02041) [[code-repo]](https://github.com/FlorianFuerrutter/genQC).
""")

col1, col2 = st.columns(2)

srv           = col1.text_input('SRV', "[1,1,1,2,2]")
num_of_qubits = col1.radio('Number of qubits (should match SRV)', [3,4,5,6,7,8], index=2)
max_gates     = col1.select_slider('Max gates', options=[4,8,12,16,20,24,28], value=16)
g             = col1.slider('Guidance scale', min_value=0.0, max_value=15.0, value=7.5)

# The SRV comes from a free-text box, so parse it defensively: malformed input
# (e.g. "[1,2" or "abc") must show an error message instead of crashing the
# script with a traceback. ast.literal_eval only evaluates Python literals.
try:
    srv_list = ast.literal_eval(srv)
except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
    srv_list = None

if (
    not isinstance(srv_list, (list, tuple))
    or not srv_list
    or not all(isinstance(v, int) for v in srv_list)
):
    st.error(f"Could not parse SRV {srv!r}: expected a non-empty list of integers such as [1,1,1,2,2].", icon="🚨")
    st.stop()

# Normalise tuple input to a list so the prompt text built from it and the
# `==` check against generated SRVs in get_qcs see the same form as the
# default list input (NOTE(review): assumes generated SRVs are lists).
srv_list = list(srv_list)
if len(srv_list)!=num_of_qubits:
    st.warning(f'Number of qubits does not match with given SRV {srv_list}. This could result in error-circuits!', icon="⚠️")

if col1.button('Generate circuits'):
    fig = get_qcs(srv_list, num_of_qubits, max_gates, g)
    col2.pyplot(fig)