# genQC / app.py
# Copied from the Hugging Face Space file view: last commit 8f69aac
# ("Update app.py") by Floki00, 3.57 kB, virus scan result: "No virus".
import streamlit as st
import wget, os, io, ast
import matplotlib.pyplot as plt
from PIL import Image
from genQC.pipeline.diffusion_pipeline import DiffusionPipeline
from genQC.inference.infer_srv import generate_srv_tensors, convert_tensors_to_srvs
from genQC.util import infer_torch_device
#--------------------------------
# download model into storage

# Local directory that receives the model config and weights.
save_destination = "saves/"

# Both files live in the same directory of a pinned commit of the genQC repo.
_model_dir_url = (
    "https://github.com/FlorianFuerrutter/genQC/blob/"
    "044f7da6ebe907bd796d3db293024db223cc1852/saves/qc_unet_config_SRV_3to8_qubit"
)
url_config = _model_dir_url + "/config.yaml"
url_weights = _model_dir_url + "/model.pt"
def download(url, dst_dir):
    """Download *url* into *dst_dir* unless a file of the same name already exists.

    The local filename is the last path component of *url*. ``?raw=true`` is
    appended to the request, presumably so GitHub serves the raw file rather
    than the HTML page of a ``/blob/`` URL.

    Args:
        url: Remote file URL.
        dst_dir: Target directory; created (including any missing parents)
            if it does not exist yet.

    Returns:
        Path of the local file (existing or freshly downloaded).
    """
    # makedirs(exist_ok=True) also creates missing parent directories and does
    # not fail if another process creates the directory concurrently, unlike
    # an exists()/mkdir() pair.
    os.makedirs(dst_dir, exist_ok=True)
    filename = os.path.join(dst_dir, os.path.basename(url))
    if not os.path.exists(filename):
        filename = wget.download(url + "?raw=true", out=filename)
    return filename
# Fetch the model config first, then the weights; download() skips any file
# already present in save_destination.
config_file, weigths_file = (
    download(remote, save_destination) for remote in (url_config, url_weights)
)
#--------------------------------
# setup

@st.cache_resource
def load_pipeline():
    """Build the diffusion pipeline from the files in ``save_destination``.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the same
    pipeline object instead of reloading it.
    """
    device = infer_torch_device()
    diffusion = DiffusionPipeline.from_config_file(save_destination, device)
    # Presumably the number of denoising steps used at sampling time — confirm.
    diffusion.scheduler.set_timesteps(20)
    return diffusion


pipeline = load_pipeline()

# NOTE(review): this flag is never read or updated in the visible code.
is_gpu_busy = False
def get_qcs(srv, num_of_qubits, max_gates, g):
    """Sample circuits for a target SRV and plot them on a 3x2 grid.

    Progress is reported through an ``st.status`` container.

    Args:
        srv: Target SRV as parsed from the text input (e.g. ``[1, 1, 1, 2, 2]``);
            it is embedded in the text prompt and compared with each
            generated circuit's SRV.
        num_of_qubits: Number of qubits; passed as both ``system_size`` and
            ``num_of_qubits`` to the generator.
        max_gates: Maximum number of gates per generated circuit.
        g: Guidance scale forwarded to the generator.

    Returns:
        The matplotlib figure. It was created via ``plt.subplots``, so the
        caller should ``plt.close`` it once it has been rendered.
    """
    # One generated sample per subplot; keep the sample count tied to the grid.
    n_rows, n_cols = 3, 2

    with st.status("Generation started", expanded=True) as status:
        st.write("Generating tensors...")
        out_tensor = generate_srv_tensors(
            pipeline,
            f"Generate SRV: {srv}",
            samples=n_rows * n_cols,
            system_size=num_of_qubits,
            num_of_qubits=num_of_qubits,
            max_gates=max_gates,
            g=g,
        )

        st.write("Converting to circuits...")
        qc_list, _, srv_list = convert_tensors_to_srvs(out_tensor, pipeline.gate_pool)

        st.write("Plotting...")
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(7, 10), constrained_layout=True, dpi=120)
        # Pre-fill every cell with a placeholder; cells that receive a circuit
        # below are cleared. Presumably the converter returns fewer circuits
        # than samples when some tensors fail to decode — confirm.
        for ax in axs.flatten():
            ax.axis('off')
            # Centre the text in axes coordinates so it stays inside its cell.
            ax.text(0.5, 0.5, "Circuit generated with errors",
                    ha="center", va="center", transform=ax.transAxes)

        for qc, circuit_srv, ax in zip(qc_list, srv_list, axs.flatten()):
            ax.clear()
            qc.draw("mpl", plot_barriers=False, ax=ax, style="clifford")
            ax.set_title(f"{'Correct' if circuit_srv==srv else 'NOT correct'}, is SRV = {circuit_srv}")

        status.update(label="Generation complete!", state="complete", expanded=False)

    return fig
#--------------------------------
# run
st.title("genQC · Generative Quantum Circuits")
st.write("""
Generating quantum circuits with diffusion models. Official demo of [[paper-arxiv]](https://arxiv.org/abs/2311.02041) [[code-repo]](https://github.com/FlorianFuerrutter/genQC).
""")

col1, col2 = st.columns(2)

srv = col1.text_input('SRV', "[1,1,1,2,2]")
num_of_qubits = col1.radio('Number of qubits (should match SRV)', [3,4,5,6,7,8], index=2)
max_gates = col1.select_slider('Max gates', options=[4,8,12,16,20,24,28], value=16)
g = col1.slider('Guidance scale', min_value=0.0, max_value=15.0, value=7.5)

# The SRV is free-form user text. literal_eval only accepts Python literals
# (unlike eval), but it raises on malformed input, and the result may not be a
# sequence of ints. Either case used to crash the app with a traceback.
try:
    srv_list = ast.literal_eval(srv)
except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
    srv_list = None
if not (isinstance(srv_list, (list, tuple)) and all(isinstance(v, int) for v in srv_list)):
    st.error(f"Could not parse SRV {srv!r}: expected a list of integers such as [1,1,1,2,2].")
    st.stop()

if len(srv_list)!=num_of_qubits:
    st.warning(f'Number of qubits does not match with given SRV {srv_list}. This could result in error-circuits!', icon="⚠️")

if col1.button('Generate circuits'):
    fig = get_qcs(srv_list, num_of_qubits, max_gates, g)
    col2.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; without this, each
    # generation on a long-running server would leak one figure.
    plt.close(fig)