Floki00 commited on
Commit
1acf699
1 Parent(s): 42cfc5f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import wget, os, io, ast
3
+ import matplotlib.pyplot as plt
4
+ from PIL import Image
5
+ from genQC.pipeline.diffusion_pipeline import DiffusionPipeline
6
+ from genQC.inference.infer_srv import generate_srv_tensors, convert_tensors_to_srvs
7
+ from genQC.util import infer_torch_device
8
+
9
+ #--------------------------------
10
+ # download model into storage
11
+
12
+ save_destination = "saves/"
13
+
14
+ url_config = "https://github.com/FlorianFuerrutter/genQC/blob/044f7da6ebe907bd796d3db293024db223cc1852/saves/qc_unet_config_SRV_3to8_qubit/config.yaml"
15
+ url_weights = "https://github.com/FlorianFuerrutter/genQC/blob/044f7da6ebe907bd796d3db293024db223cc1852/saves/qc_unet_config_SRV_3to8_qubit/model.pt"
16
+
17
+ def download(url, dst_dir):
18
+ if not os.path.exists(dst_dir): os.mkdir(dst_dir)
19
+ filename = os.path.join(dst_dir, os.path.basename(url))
20
+ if not os.path.exists(filename): filename = wget.download(url + "?raw=true", out=filename)
21
+ return filename
22
+
23
+ config_file = download(url_config, save_destination)
24
+ weigths_file = download(url_weights, save_destination)
25
+
26
+ #--------------------------------
27
+ # setup
28
+
29
+ try:
30
+ pipeline
31
+ except:
32
+ pipeline = DiffusionPipeline.from_config_file(save_destination, infer_torch_device())
33
+ pipeline.scheduler.set_timesteps(20)
34
+
35
+ is_gpu_busy = False
36
+ def get_correct_qcs_image(srv, num_of_qubits, max_gates, g):
37
+ global is_gpu_busy
38
+
39
+ out_tensor = generate_srv_tensors(pipeline, f"Generate SRV: {srv}", samples=6, system_size=num_of_qubits, num_of_qubits=num_of_qubits, max_gates=max_gates, g=g)
40
+ qc_list, _, svr_list = convert_tensors_to_srvs(out_tensor, pipeline.gate_pool)
41
+
42
+ fig, axs = plt.subplots(2, 3, figsize=(15,7), constrained_layout=True)
43
+ for qc,is_svr,ax in zip(qc_list, svr_list, axs.flatten()):
44
+ qc.draw("mpl", plot_barriers=False, ax=ax)
45
+ ax.set_title(f"{'Correct' if is_svr==srv else 'NOT correct'}, is SRV = {is_svr}")
46
+
47
+ buf = io.BytesIO()
48
+ fig.savefig(buf)
49
+ buf.seek(0)
50
+ return Image.open(buf)
51
+
52
+ #--------------------------------
53
+ # run
54
+
55
+ st.title("genQC · Generative Quantum Circuits")
56
+ st.write("""
57
+ Generating quantum circuits with diffusion models. Official demo of [[paper-arxiv]](https://arxiv.org/abs/2311.02041) [[code-repo]](https://github.com/FlorianFuerrutter/genQC).
58
+ """)
59
+
60
+ col1, col2 = st.columns(2)
61
+
62
+ srv = col1.text_input('SRV', "[1,1,1,2,2]")
63
+ num_of_qubits = col1.radio('Number of qubits (should match SRV)', [3,4,5,6,7,8], index=2)
64
+ max_gates = col1.select_slider('Max gates', options=[4,8,12,16,20,24,28], value=16)
65
+ g = col1.slider('Guidance scale', min_value=0.0, max_value=15.0, value=7.5)
66
+
67
+ if col1.button('Generate circuits'):
68
+ image = get_correct_qcs_image(ast.literal_eval(srv), num_of_qubits, max_gates, g)
69
+ col2.image(image, use_column_width=True)