Generalized Multiplexing: Simple Example
This notebook gives a brief tutorial on how to define a Qubrick that uses a multiplexor. The example is not functional, in the sense that no statevector simulation is run or checked here. It is meant only to show the process of writing such Qubricks.
The example is the unary encoding of the unary iteration circuit from "Even more efficient quantum computations of chemistry through tensor hypercontraction", with the inner prepare and select operations left uncompiled. For drawing purposes, we define two dummy Qubricks:
%load_ext autoreload
%autoreload 2
from psiqworkbench import Qubits, QPU, Qubrick
from workbench_algorithms import ZeroAncMultiplexor, BinaryTreeMultiplexor, OneAncMultiplexor, SawtoothMultiplexor, SelectNaive
import numpy as np
# Dummy PREPARE: only draws a labelled box, no gates are applied
class PrepBox(Qubrick):
    def _compute(self, target, ctrl=0):
        target.box_open("Prep", cond=ctrl)
        target.box_close("Prep", cond=ctrl)

# Dummy SELECT: only draws a labelled box, no gates are applied
class SelectBox(Qubrick):
    def _compute(self, index_reg, target, data, ctrl=0):
        (index_reg | target).box_open("select", cond=ctrl)
        (index_reg | target).box_close("select", cond=ctrl)
To use the UnaryIteration Qubrick for actual simulations, replace these dummies with functional counterparts.
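As a hedged illustration of what a functional PREP counterpart might compute, the numpy sketch below builds a hypothetical single-qubit state-preparation unitary. It is a real Ry rotation, and prep_unitary together with the weights 0.3 and 0.7 are our own illustrative choices, not part of workbench_algorithms. The sketch checks the identity behind the reflection used in _multiplex_func: conjugating I - 2|0><0| by PREP gives I - 2|p><p|, a reflection about the prepared state |p> = PREP|0>. We assume PREP is applied on the left and PREP† on the right. Which way round prep.computed(..., dagger=True) applies them is not shown in this notebook.

```python
import numpy as np

def prep_unitary(w0: float, w1: float) -> np.ndarray:
    """Hypothetical 1-qubit PREP mapping |0> to sqrt(w0/lam)|0> + sqrt(w1/lam)|1>."""
    lam = w0 + w1
    theta = 2 * np.arccos(np.sqrt(w0 / lam))
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    # Ry(theta): its first column is the prepared state
    return np.array([[c, -s], [s, c]])

P = prep_unitary(0.3, 0.7)
zero = np.array([1.0, 0.0])
p = P @ zero  # prepared state PREP|0>

# Reflection about |0>: I - 2|0><0|
R0 = np.eye(2) - 2 * np.outer(zero, zero)

# Conjugating by PREP turns it into the reflection about |p>
R = P @ R0 @ P.conj().T
assert np.allclose(R, np.eye(2) - 2 * np.outer(p, p.conj()))
print(p)
```

The reflection flips the sign of |p> and leaves the orthogonal state unchanged. This is why only the "ell == 0" branch needs a Z inside the prep/unprep pair.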
Next we define the main Qubrick:
# Boilerplate: every Qubrick needs its sub-Qubricks to be defined at initialization.
# In this case, we need internal prep and select Qubricks and a multiplexor.
class UnaryIteration(Qubrick):
    def __init__(self, prep, select, multiplexor, **kwargs):
        self.prep = prep
        self.select = select
        self.multiplexor = multiplexor
        super().__init__(**kwargs)

    # Pass whatever you like to the _compute method. Here we could allocate
    # some temporary qubits, but for now we just pass everything through.
    def _compute(self, index_reg, ell_reg, psi, num_terms, inner_data, ctrl=0):
        multiplexor = self.multiplexor
        # Apply an operation at every index
        used_indices = [1] * num_terms
        # Use the factory method below to build the multiplexing function
        _multiplexing_function = self.generate_multiplex_func(ell_reg, psi, inner_data, used_indices)
        # Compute the multiplexor using the multiplexing function
        multiplexor.compute(index_reg, _multiplexing_function, used_indices, ctrl=ctrl)

    # All of the Qubrick's functionality is defined here.
    # This method's signature can be as long as needed.
    def generate_multiplex_func(self, ell_reg, psi, inner_data, used_indices):
        # We can still access all of the Qubrick's attributes and methods
        qc = self.get_qc()
        prep = self.prep
        select = self.select

        # The multiplexing function itself: note the simple signature,
        # which is limited to these three arguments
        def _multiplex_func(index, index_reg, ctrl=0):
            with prep.computed(ell_reg, dagger=True):
                # Reflection about I - 2|0><0|
                qc.z(0, index_reg | (ell_reg == 0) | ctrl)
            if index != max(used_indices) + 1:
                select.compute(ell_reg, psi, inner_data)

        # Return the multiplexing function so the multiplexor can call it
        return _multiplex_func
Most of the functionality lives in _multiplex_func, which is returned by the generate_multiplex_func method. Because _multiplex_func is a closure, it captures ell_reg, psi, inner_data and the sub-Qubricks. In this way we can "freeze out" registers and other parameters without worrying about how the multiplexor will handle them.
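The same factory/closure pattern can be sketched in plain Python, independently of psiqworkbench. Here toy_multiplexor and the string "registers" are hypothetical stand-ins. The toy multiplexor only knows how to loop over indices and call the supplied function with (index, index_reg, ctrl). It treats used_indices as a 0/1 mask per index, which is our assumption and not documented behaviour of the real multiplexors.

```python
from typing import Callable, List

# Hypothetical stand-in for a multiplexor: it never sees ell_reg or psi,
# it only calls func(index, index_reg, ctrl) for every used index.
def toy_multiplexor(index_reg: str, func: Callable, used_indices: List[int], ctrl=0) -> None:
    for index, used in enumerate(used_indices):
        if used:
            func(index, index_reg, ctrl=ctrl)

def generate_multiplex_func(ell_reg: str, psi: str, log: List[str]) -> Callable:
    # ell_reg, psi and log are captured ("frozen") by the closure below
    def _multiplex_func(index, index_reg, ctrl=0):
        log.append(f"index={index}: act on {ell_reg},{psi} controlled by {index_reg}")
    return _multiplex_func

log: List[str] = []
func = generate_multiplex_func("ell", "psi", log)
toy_multiplexor("index", func, [1, 0, 1])
print("\n".join(log))
```

Only indices 0 and 2 are logged, and the multiplexor never needed to know which registers the function acts on.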
Users should not access the multiplexing functions directly. They should be defined within the Qubrick and called internally, as in the example above. From the user's perspective, the Qubrick simply performs the desired operation for each index, and we can choose whichever multiplexor we like to iterate over those indices.
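Mathematically, a multiplexed operation applies U_i to the target whenever the index register is in state |i>. In other words, it implements the block-diagonal operator sum_i |i><i| ⊗ U_i. The numpy sketch below builds that operator for four arbitrary one-qubit gates, chosen only for illustration. It then checks the defining property on each index basis state. Our understanding is that the multiplexors used below differ in how they decompose such an operator into gates and ancillas, not in the operator itself.

```python
import numpy as np

def multiplexed_unitary(unitaries):
    """Return sum_i |i><i| (x) U_i, with the index register as the most significant factor."""
    n = len(unitaries)
    d = unitaries[0].shape[0]
    M = np.zeros((n * d, n * d), dtype=complex)
    for i, U in enumerate(unitaries):
        # Place U_i on the i-th diagonal block
        M[i * d:(i + 1) * d, i * d:(i + 1) * d] = U
    return M

X = np.array([[0, 1], [1, 0]])
Z = np.diag([1, -1])
I2 = np.eye(2)
Us = [I2, X, Z, X @ Z]  # one 1-qubit operation per index value 0..3
M = multiplexed_unitary(Us)

psi = np.array([0.6, 0.8])
for i, U in enumerate(Us):
    index_state = np.eye(len(Us))[i]  # |i> on the index register
    out = M @ np.kron(index_state, psi)
    # The index register is untouched and the target receives U_i
    assert np.allclose(out, np.kron(index_state, U @ psi))
```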
qc = QPU()
qc.reset(6)  # 4 index + 1 ell + 1 psi qubits, with none left for ancillas
index = Qubits(4, "index", qc)
ell_reg = Qubits(1, "ell", qc)
psi = Qubits(1, "psi", qc)
unary_iter = UnaryIteration(PrepBox(), SelectBox(), ZeroAncMultiplexor())
unary_iter.compute(index, ell_reg, psi, 12, inner_data=[1])
qc.draw()
qc = QPU()
qc.reset(7)  # the 6 register qubits plus 1 for the multiplexor's ancilla
index = Qubits(4, "index", qc)
ell_reg = Qubits(1, "ell", qc)
psi = Qubits(1, "psi", qc)
unary_iter = UnaryIteration(PrepBox(), SelectBox(), OneAncMultiplexor())
unary_iter.compute(index, ell_reg, psi, 12, inner_data=[1])
qc.draw()
qc = QPU()
qc.reset(10)  # the 6 register qubits plus 4 available for the multiplexor's ancillas
index = Qubits(4, "index", qc)
ell_reg = Qubits(1, "ell", qc)
psi = Qubits(1, "psi", qc)
unary_iter = UnaryIteration(PrepBox(), SelectBox(), SawtoothMultiplexor())
unary_iter.compute(index, ell_reg, psi, 12, inner_data=[1])
qc.draw()
qc = QPU()
qc.reset(10)  # the 6 register qubits plus 4 available for the multiplexor's ancillas
index = Qubits(4, "index", qc)
ell_reg = Qubits(1, "ell", qc)
psi = Qubits(1, "psi", qc)
unary_iter = UnaryIteration(PrepBox(), SelectBox(), BinaryTreeMultiplexor())
unary_iter.compute(index, ell_reg, psi, 12, inner_data=[1])
qc.draw()
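All four runs pass num_terms=12 with a 4-qubit index register. The short sketch below again assumes used_indices acts as a per-index 0/1 mask. It lists the index values that carry an operation and those that do not, and checks that 4 qubits is the minimum index width for 12 terms.

```python
num_index_qubits = 4
num_terms = 12
used_indices = [1] * num_terms  # as in UnaryIteration._compute

# Minimum width needed to label indices 0 .. num_terms - 1
min_qubits = max(1, (num_terms - 1).bit_length())

# Index values that carry an operation, as bit strings (most significant bit first)
visited = [format(i, f"0{num_index_qubits}b") for i, used in enumerate(used_indices) if used]
# Basis states of the index register that never receive an operation
unused = [format(i, f"0{num_index_qubits}b") for i in range(num_terms, 2 ** num_index_qubits)]

print(min_qubits)
print(visited)
print(unused)
```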