File size: 3,647 Bytes
1acf699 fa0e40e 1acf699 1ff7cbe 1acf699 1ff7cbe 1acf699 1ff7cbe 1acf699 f26c118 1ff7cbe f26c118 a9c34c4 1acf699 f26c118 1acf699 adc0e7a 8f69aac adc0e7a 0a95580 8f69aac 0a95580 1ff7cbe adc0e7a 47a9ca5 1acf699 0a95580 1acf699 1ff7cbe 0a95580 1ff7cbe 1acf699 adc0e7a 0a95580 adc0e7a 865ea06 adc0e7a 47a9ca5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import streamlit as st
import os, io, ast #wget
import matplotlib.pyplot as plt
from PIL import Image
from genQC.pipeline.diffusion_pipeline import DiffusionPipeline
from genQC.inference.infer_srv import generate_srv_tensors, convert_tensors_to_srvs
from genQC.util import infer_torch_device
#--------------------------------
# download model into storage
#save_destination = "saves/"
#url_config = "https://github.com/FlorianFuerrutter/genQC/blob/044f7da6ebe907bd796d3db293024db223cc1852/saves/qc_unet_config_SRV_3to8_qubit/config.yaml"
#url_weights = "https://github.com/FlorianFuerrutter/genQC/blob/044f7da6ebe907bd796d3db293024db223cc1852/saves/qc_unet_config_SRV_3to8_qubit/model.pt"
#def download(url, dst_dir):
# if not os.path.exists(dst_dir): os.mkdir(dst_dir)
# filename = os.path.join(dst_dir, os.path.basename(url))
# if not os.path.exists(filename): filename = wget.download(url + "?raw=true", out=filename)
# return filename
#config_file = download(url_config, save_destination)
#weigths_file = download(url_weights, save_destination)
#--------------------------------
# setup
@st.cache_resource
def load_pipeline():
    """Build the pretrained SRV diffusion pipeline used by the demo.

    Decorated with ``st.cache_resource`` so Streamlit reuses the returned
    object across script reruns instead of reloading the model each time.

    Returns:
        The ``DiffusionPipeline`` loaded from the "Floki00/qc_srv_3to8qubit"
        checkpoint on the CPU, with its scheduler set to 20 timesteps.
    """
    # Disabled local-file alternative:
    #   DiffusionPipeline.from_config_file(save_destination, infer_torch_device())
    hub_pipeline = DiffusionPipeline.from_pretrained("Floki00/qc_srv_3to8qubit", "cpu")
    # presumably the number of denoising steps used at sampling time — confirm in genQC
    hub_pipeline.scheduler.set_timesteps(20)
    return hub_pipeline


pipeline = load_pipeline()
# NOTE(review): this flag is never read or assigned anywhere else in the visible script.
is_gpu_busy = False
def get_qcs(srv, num_of_qubits, max_gates, g):
    """Generate circuits for a target SRV and plot them in a 3x2 grid.

    Progress is reported in a Streamlit status container while the pipeline
    samples tensors, decodes them into circuits and draws them.

    Args:
        srv: target Schmidt-rank vector (e.g. ``[1, 1, 2]``). It is embedded in
            the text prompt and compared against each decoded circuit's SRV.
        num_of_qubits: forwarded as both ``system_size`` and ``num_of_qubits``
            to ``generate_srv_tensors``.
        max_gates: maximum gate count forwarded to ``generate_srv_tensors``.
        g: guidance scale forwarded to ``generate_srv_tensors``.

    Returns:
        A matplotlib figure with one panel per grid cell. Each panel that
        received a decoded circuit shows the circuit, titled with whether its
        SRV equals ``srv``. The remaining panels show a placeholder message.
    """
    n_rows, n_cols = 3, 2

    with st.status("Generation started", expanded=True) as status:
        st.write("Generating tensors...")
        # Request exactly one sample per grid panel so the two cannot drift apart.
        out_tensor = generate_srv_tensors(
            pipeline,
            f"Generate SRV: {srv}",
            samples=n_rows * n_cols,
            system_size=num_of_qubits,
            num_of_qubits=num_of_qubits,
            max_gates=max_gates,
            g=g,
        )

        st.write("Converting to circuits...")
        qc_list, _, srv_list = convert_tensors_to_srvs(out_tensor, pipeline.gate_pool)

        st.write("Plotting...")
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(7, 10), constrained_layout=True, dpi=120)
        panels = axs.flatten()

        # Pre-fill every panel with a placeholder. Panels that receive a circuit
        # below are cleared and redrawn, so only leftovers keep this message.
        # NOTE(review): presumably convert_tensors_to_srvs drops tensors that fail
        # to decode, so fewer circuits than panels may come back — confirm.
        for ax in panels:
            ax.axis('off')
            # Centre the message; left-anchored at x=0.5 it ran past the panel edge.
            ax.text(0.5, 0.5, "Circuit generated with errors", ha="center", va="center")

        for qc, got_srv, ax in zip(qc_list, srv_list, panels):
            ax.clear()
            qc.draw("mpl", plot_barriers=False, ax=ax)
            ax.set_title(f"{'Correct' if got_srv == srv else 'NOT correct'}, is SRV = {got_srv}")

        status.update(label="Generation complete!", state="complete", expanded=False)
    return fig
#--------------------------------
# run
def _parse_srv(text):
    """Parse the SRV text field into a list of ints, or return None if invalid.

    Accepts a Python list or tuple literal of integers such as "[1,1,2]" or
    "(1, 1, 2)". Tuples are normalised to lists so the prompt text and the
    "Correct / NOT correct" comparison always see the same form. Anything else
    (typos, non-sequence literals, non-integer entries) yields None, so the
    caller can show an error instead of crashing with a traceback.
    """
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # These are the exceptions ast.literal_eval documents for malformed input.
        return None
    if not isinstance(parsed, (list, tuple)):
        return None
    # bool is a subclass of int, so reject it explicitly.
    if not all(isinstance(v, int) and not isinstance(v, bool) for v in parsed):
        return None
    return list(parsed)


st.title("genQC · Generative Quantum Circuits")
st.write("""
Generating quantum circuits with diffusion models. Official demo of [[paper-arxiv]](https://arxiv.org/abs/2311.02041) [[code-repo]](https://github.com/FlorianFuerrutter/genQC).
""")

col1, col2 = st.columns(2)
srv = col1.text_input('SRV', "[1,1,1,2,2,2]")
num_of_qubits = col1.radio('Number of qubits (should match SRV)', [3,4,5,6,7,8], index=3)
max_gates = col1.select_slider('Max gates', options=[4,8,12,16,20,24,28], value=16)
# value must be a float like min_value/max_value. Streamlit raises a
# StreamlitAPIException when slider arguments mix int and float types.
g = col1.slider('Guidance scale', min_value=0.0, max_value=15.0, value=10.0)

srv_list = _parse_srv(srv)
if srv_list is None:
    col1.error(f"Could not parse SRV {srv!r}: expected a list of integers such as [1,1,2].")
    st.stop()  # end this script run; nothing below can work without a valid SRV

if len(srv_list) != num_of_qubits:
    st.warning(f'Number of qubits does not match with given SRV {srv_list}. This could result in error-circuits!', icon="⚠️")

if col1.button('Generate circuits'):
    fig = get_qcs(srv_list, num_of_qubits, max_gates, g)
    col2.pyplot(fig)
    # pyplot keeps every open figure alive. Close this one after rendering so a
    # long-running server does not accumulate one figure per click.
    plt.close(fig)