Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions ibis/backends/risingwave/tests/test_window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

import pytest

import ibis
from ibis import _


def test_tumble_window_by_grouped_agg(alltypes):
    """Check a 10-day tumbling window aggregation grouped by ``string_col``.

    The result must expose the window bounds first, then the grouping key,
    then the aggregated mean of ``float_col``.
    """
    expected_columns = ["window_start", "window_end", "string_col", "avg"]

    windowed = alltypes.window_by(alltypes.timestamp_col).tumble(
        size=ibis.interval(days=10)
    )
    frame = windowed.agg(by=["string_col"], avg=_.float_col.mean()).to_pandas()

    assert list(frame.columns) == expected_columns
    assert frame.shape == (740, len(expected_columns))


def test_tumble_window_by_ungrouped_agg(alltypes):
    """Check a 1-day tumbling window aggregation with no grouping keys.

    Only the window bounds and the aggregated mean of ``float_col`` are
    expected in the output.
    """
    expected_columns = ["window_start", "window_end", "avg"]

    windowed = alltypes.window_by(alltypes.timestamp_col).tumble(
        size=ibis.interval(days=1)
    )
    frame = windowed.agg(avg=_.float_col.mean()).to_pandas()

    assert list(frame.columns) == expected_columns
    assert frame.shape == (730, len(expected_columns))


def test_hop_window_by_grouped_agg(alltypes):
t = alltypes
expr = (
t.window_by(t.timestamp_col)
.hop(size=ibis.interval(days=10), slide=ibis.interval(days=10))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should use a different slide so that it's not essentially a tumbling window? :) I'm sure it works, but even reviewing at a glance I noticed the results are the same.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried using size=20 days, but I couldn't quite understand the resulting window_start/window_end values produced by RisingWave. So I changed it back to use the same size and slide to ensure I wouldn't introduce wrong results.

.agg(by=["string_col"], avg=_.float_col.mean())
)
result = expr.to_pandas()
assert list(result.columns) == ["window_start", "window_end", "string_col", "avg"]
assert result.shape == (740, 4)


def test_hop_window_by_ungrouped_agg(alltypes):
    """Check a hopping window aggregation with no grouping keys.

    Size and slide are both one day here; only the window bounds and the
    aggregated mean of ``float_col`` are expected in the output.
    """
    expected_columns = ["window_start", "window_end", "avg"]
    one_day = ibis.interval(days=1)

    windowed = alltypes.window_by(alltypes.timestamp_col).hop(
        size=one_day, slide=one_day
    )
    frame = windowed.agg(avg=_.float_col.mean()).to_pandas()

    assert list(frame.columns) == expected_columns
    assert frame.shape == (730, len(expected_columns))
48 changes: 48 additions & 0 deletions ibis/backends/sql/compilers/risingwave.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,5 +194,53 @@ def visit_MapContains(self, op, *, arg, key):
self.cast(arg, op.arg.dtype), self.cast(key, op.key.dtype)
)

def visit_WindowAggregate(
    self,
    op,
    *,
    parent,
    window_type,
    time_col,
    groups,
    metrics,
    window_size,
    window_slide,
    window_offset,
):
    """Build the SELECT for an aggregation over a time-window table function.

    The function named by ``window_type`` is looked up on ``self.f`` and
    called with the parent table, the time column, and whichever of
    slide / size / offset are set (in that order). The result is aliased
    back to the parent's name, and the query selects and groups by the
    ``window_start`` / ``window_end`` columns plus ``groups``, projecting
    ``metrics`` alongside them.

    A ``tumble`` window must not carry a slide.
    """
    if window_type == "tumble":
        assert window_slide is None

    alias = parent.alias_or_name

    # Positional arguments in the order the window function takes them;
    # entries that are unset (falsy) are dropped before the call.
    candidate_args = (
        self.v[parent.this.sql(self.dialect)],
        time_col.this,
        window_slide,
        window_size,
        window_offset,
    )
    call_args = [arg for arg in candidate_args if arg]

    make_window = getattr(self.f, window_type)

    # References to the window-bound columns the backend adds to the output.
    bounds = [
        sg.column(name, table=alias, quoted=True)
        for name in ("window_start", "window_end")
    ]

    projection = [
        *bounds,
        *self._cleanup_names(groups),
        *self._cleanup_names(metrics),
    ]
    query = sg.select(*projection, copy=False).from_(
        make_window(*call_args).as_(alias, quoted=True)
    )

    group_keys = self._generate_groups([*bounds, *groups.values()])
    return query.group_by(*group_keys)


# Module-level RisingWaveCompiler instance.
compiler = RisingWaveCompiler()
Loading