diff --git a/ibis/backends/risingwave/tests/test_window.py b/ibis/backends/risingwave/tests/test_window.py new file mode 100644 index 000000000000..b6e85fd4723f --- /dev/null +++ b/ibis/backends/risingwave/tests/test_window.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import ibis +from ibis import _ + + +def test_tumble_window_by_grouped_agg(alltypes): + t = alltypes + expr = ( + t.window_by(t.timestamp_col) + .tumble(size=ibis.interval(days=10)) + .agg(by=["string_col"], avg=_.float_col.mean()) + ) + result = expr.to_pandas() + assert list(result.columns) == ["window_start", "window_end", "string_col", "avg"] + assert result.shape == (740, 4) + + +def test_tumble_window_by_ungrouped_agg(alltypes): + t = alltypes + expr = ( + t.window_by(t.timestamp_col) + .tumble(size=ibis.interval(days=1)) + .agg(avg=_.float_col.mean()) + ) + result = expr.to_pandas() + assert list(result.columns) == ["window_start", "window_end", "avg"] + assert result.shape == (730, 3) + + +def test_hop_window_by_grouped_agg(alltypes): + t = alltypes + expr = ( + t.window_by(t.timestamp_col) + .hop(size=ibis.interval(days=10), slide=ibis.interval(days=10)) + .agg(by=["string_col"], avg=_.float_col.mean()) + ) + result = expr.to_pandas() + assert list(result.columns) == ["window_start", "window_end", "string_col", "avg"] + assert result.shape == (740, 4) + + +def test_hop_window_by_ungrouped_agg(alltypes): + t = alltypes + expr = ( + t.window_by(t.timestamp_col) + .hop(size=ibis.interval(days=1), slide=ibis.interval(days=1)) + .agg(avg=_.float_col.mean()) + ) + result = expr.to_pandas() + assert list(result.columns) == ["window_start", "window_end", "avg"] + assert result.shape == (730, 3) diff --git a/ibis/backends/sql/compilers/risingwave.py b/ibis/backends/sql/compilers/risingwave.py index bbce332063fe..656a28cbbf04 100644 --- a/ibis/backends/sql/compilers/risingwave.py +++ b/ibis/backends/sql/compilers/risingwave.py @@ -194,5 +194,53 @@ def visit_MapContains(self, op, *, arg, key): self.cast(arg, op.arg.dtype), self.cast(key, op.key.dtype) ) + def visit_WindowAggregate( + self, + op, + *, + parent, + window_type, + time_col, + groups, + metrics, + window_size, + window_slide, + window_offset, + ): + if window_type == "tumble": + assert window_slide is None + + args = [ + self.v[parent.this.sql(self.dialect)], + time_col.this, + window_slide, + window_size, + window_offset, + ] + + window_func = getattr(self.f, window_type) + + # create column references to new columns generated by backend in the output + window_start = sg.column( + "window_start", table=parent.alias_or_name, quoted=True + ) + window_end = sg.column("window_end", table=parent.alias_or_name, quoted=True) + + return ( + sg.select( + window_start, + window_end, + *self._cleanup_names(groups), + *self._cleanup_names(metrics), + copy=False, + ) + .from_( + window_func(*filter(None, args)).as_(parent.alias_or_name, quoted=True) + ) + .group_by( + *self._generate_groups([window_start, window_end, *groups.values()]) + ) + ) + compiler = RisingWaveCompiler()