diff --git a/crates/circuit/src/dag_circuit.rs b/crates/circuit/src/dag_circuit.rs index a0c06be401b8..99e966eb68d7 100644 --- a/crates/circuit/src/dag_circuit.rs +++ b/crates/circuit/src/dag_circuit.rs @@ -1554,14 +1554,17 @@ impl DAGCircuit { .map(|s| s.extract()) .collect::>>()?; for wire in wires { - let nodes_found = self.nodes_on_wire(wire, true).into_iter().any(|node| { - let weight = self.dag.node_weight(node).unwrap(); - if let NodeType::Operation(packed) = weight { - !ignore_set.contains(packed.op.name()) - } else { - false - } - }); + let nodes_found = self + .nodes_on_wire(wire) + .filter(|node| matches!(self.dag[*node], NodeType::Operation(_))) + .any(|node| { + let weight = self.dag.node_weight(node).unwrap(); + if let NodeType::Operation(packed) = weight { + !ignore_set.contains(packed.op.name()) + } else { + false + } + }); if !nodes_found { result.push(match wire { @@ -4167,8 +4170,8 @@ impl DAGCircuit { })?; let nodes = self - .nodes_on_wire(wire, only_ops) - .into_iter() + .nodes_on_wire(wire) + .filter(|node| !only_ops || matches!(self.dag[*node], NodeType::Operation(_))) .map(|n| self.get_node(py, n)) .collect::>>()?; Ok(PyTuple::new(py, nodes)?.into_any().try_iter()?.unbind()) @@ -6238,33 +6241,20 @@ impl DAGCircuit { self.global_phase = Param::Float(angle.rem_euclid(::std::f64::consts::TAU)); } - /// Get the nodes on the given wire. + /// Get the nodes on a given wire as a non-allocating iterator /// - /// Note: result is empty if the wire is not in the DAG. - pub fn nodes_on_wire(&self, wire: Wire, only_ops: bool) -> Vec { - let mut nodes = Vec::new(); - let mut current_node = match wire { - Wire::Qubit(qubit) => self.qubit_io_map.get(qubit.index()).map(|x| x[0]), - Wire::Clbit(clbit) => self.clbit_io_map.get(clbit.index()).map(|x| x[0]), - Wire::Var(var) => self.var_io_map.get(var.index()).map(|x| x[0]), + /// This will panic if the wire is not in the circuit + pub fn nodes_on_wire(&self, wire: Wire) -> impl Iterator { + let start_node = match wire { + Wire::Qubit(qubit) => self.qubit_io_map[qubit.index()][0], + Wire::Clbit(clbit) => self.clbit_io_map[clbit.index()][0], + Wire::Var(var) => self.var_io_map[var.index()][0], }; - - while let Some(node) = current_node { - if only_ops { - let node_weight = self.dag.node_weight(node).unwrap(); - if let NodeType::Operation(_) = node_weight { - nodes.push(node); - } - } else { - nodes.push(node); - } - - let edges = self.dag.edges_directed(node, Outgoing); - current_node = edges - .into_iter() - .find_map(|edge| (*edge.weight() == wire).then_some(edge.target())); + NodesOnWireIter { + dag: self, + wire, + next_node: Some(start_node), } - nodes } fn remove_idle_wire(&mut self, wire: Wire) { @@ -8117,6 +8107,40 @@ impl DAGCircuit { } } +struct NodesOnWireIter<'a> { + dag: &'a DAGCircuit, + wire: Wire, + next_node: Option, +} + +impl<'a> Iterator for NodesOnWireIter<'a> { + type Item = NodeIndex; + + fn size_hint(&self) -> (usize, Option) { + // If the wire is empty there are the input and output nodes which is the minimum + // If the dag is all on a single wire we would have all the operations + the 2 io + // nodes. + (2, Some(self.dag.num_ops() + 2)) + } + + fn next(&mut self) -> Option { + let out_node = self.next_node?; + self.next_node = self + .dag + .dag + .edges_directed(out_node, Outgoing) + .filter_map(|e: EdgeReference<'a, Wire>| { + if e.weight() == &self.wire { + Some(e.target()) + } else { + None + } + }) + .next(); + Some(out_node) + } +} + pub struct DAGCircuitBuilder { dag: DAGCircuit, last_clbits: Vec>, diff --git a/crates/transpiler/src/passes/commutation_analysis.rs b/crates/transpiler/src/passes/commutation_analysis.rs index beb0d7818869..8c87bb616e7d 100644 --- a/crates/transpiler/src/passes/commutation_analysis.rs +++ b/crates/transpiler/src/passes/commutation_analysis.rs @@ -61,7 +61,7 @@ pub fn analyze_commutations( for qubit in 0..dag.num_qubits() { let wire = Wire::Qubit(Qubit(qubit as u32)); - for current_gate_idx in dag.nodes_on_wire(wire, false) { + for current_gate_idx in dag.nodes_on_wire(wire) { // get the commutation set associated with the current wire, or create a new // index set containing the current gate let commutation_entry = commutation_set