Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 58 additions & 34 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1554,14 +1554,17 @@ impl DAGCircuit {
.map(|s| s.extract())
.collect::<PyResult<HashSet<String>>>()?;
for wire in wires {
let nodes_found = self.nodes_on_wire(wire, true).into_iter().any(|node| {
let weight = self.dag.node_weight(node).unwrap();
if let NodeType::Operation(packed) = weight {
!ignore_set.contains(packed.op.name())
} else {
false
}
});
let nodes_found = self
.nodes_on_wire(wire)
.filter(|node| matches!(self.dag[*node], NodeType::Operation(_)))
.any(|node| {
let weight = self.dag.node_weight(node).unwrap();
if let NodeType::Operation(packed) = weight {
!ignore_set.contains(packed.op.name())
} else {
false
}
});

if !nodes_found {
result.push(match wire {
Expand Down Expand Up @@ -4167,8 +4170,8 @@ impl DAGCircuit {
})?;

let nodes = self
.nodes_on_wire(wire, only_ops)
.into_iter()
.nodes_on_wire(wire)
.filter(|node| !only_ops || matches!(self.dag[*node], NodeType::Operation(_)))
.map(|n| self.get_node(py, n))
.collect::<PyResult<Vec<_>>>()?;
Ok(PyTuple::new(py, nodes)?.into_any().try_iter()?.unbind())
Expand Down Expand Up @@ -6238,33 +6241,20 @@ impl DAGCircuit {
self.global_phase = Param::Float(angle.rem_euclid(::std::f64::consts::TAU));
}

/// Get the nodes on the given wire.
/// Get the nodes on a given wire as a non-allocating iterator
///
/// Note: result is empty if the wire is not in the DAG.
pub fn nodes_on_wire(&self, wire: Wire, only_ops: bool) -> Vec<NodeIndex> {
let mut nodes = Vec::new();
let mut current_node = match wire {
Wire::Qubit(qubit) => self.qubit_io_map.get(qubit.index()).map(|x| x[0]),
Wire::Clbit(clbit) => self.clbit_io_map.get(clbit.index()).map(|x| x[0]),
Wire::Var(var) => self.var_io_map.get(var.index()).map(|x| x[0]),
/// This will panic if the wire is not in the circuit
pub fn nodes_on_wire(&self, wire: Wire) -> impl Iterator<Item = NodeIndex> {
let start_node = match wire {
Wire::Qubit(qubit) => self.qubit_io_map[qubit.index()][0],
Wire::Clbit(clbit) => self.clbit_io_map[clbit.index()][0],
Wire::Var(var) => self.var_io_map[var.index()][0],
};

while let Some(node) = current_node {
if only_ops {
let node_weight = self.dag.node_weight(node).unwrap();
if let NodeType::Operation(_) = node_weight {
nodes.push(node);
}
} else {
nodes.push(node);
}

let edges = self.dag.edges_directed(node, Outgoing);
current_node = edges
.into_iter()
.find_map(|edge| (*edge.weight() == wire).then_some(edge.target()));
NodesOnWireIter {
dag: self,
wire,
next_node: Some(start_node),
}
nodes
}

fn remove_idle_wire(&mut self, wire: Wire) {
Expand Down Expand Up @@ -8117,6 +8107,40 @@ impl DAGCircuit {
}
}

struct NodesOnWireIter<'a> {
dag: &'a DAGCircuit,
wire: Wire,
next_node: Option<NodeIndex>,
}

impl<'a> Iterator for NodesOnWireIter<'a> {
type Item = NodeIndex;

fn size_hint(&self) -> (usize, Option<usize>) {
// If the wire is empty there are the input and output nodes which is the minimum
// If the dag is all on a single wire we would have all the operations + the 2 io
// nodes.
(2, Some(self.dag.num_ops() + 2))
}

fn next(&mut self) -> Option<Self::Item> {
let out_node = self.next_node?;
self.next_node = self
.dag
.dag
.edges_directed(out_node, Outgoing)
.filter_map(|e: EdgeReference<'a, Wire>| {
if e.weight() == &self.wire {
Some(e.target())
} else {
None
}
})
.next();
Some(out_node)
}
}

pub struct DAGCircuitBuilder {
dag: DAGCircuit,
last_clbits: Vec<Option<NodeIndex>>,
Expand Down
2 changes: 1 addition & 1 deletion crates/transpiler/src/passes/commutation_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub fn analyze_commutations(
for qubit in 0..dag.num_qubits() {
let wire = Wire::Qubit(Qubit(qubit as u32));

for current_gate_idx in dag.nodes_on_wire(wire, false) {
for current_gate_idx in dag.nodes_on_wire(wire) {
// get the commutation set associated with the current wire, or create a new
// index set containing the current gate
let commutation_entry = commutation_set
Expand Down