diff --git a/ibis/backends/databricks/__init__.py b/ibis/backends/databricks/__init__.py index 24bbf55de565..bc38d1c142a0 100644 --- a/ibis/backends/databricks/__init__.py +++ b/ibis/backends/databricks/__init__.py @@ -254,7 +254,9 @@ def create_table( return self.table(name, database=(catalog, database)) - def table(self, name: str, /, *, database: str | None = None) -> ir.Table: + def table( + self, name: str, /, *, database: str | tuple[str, str] | None = None + ) -> ir.Table: table_loc = self._to_sqlglot_table(database) # TODO: set these to better defaults @@ -500,15 +502,17 @@ def finalizer(path: str = path, con=self.con) -> None: def create_database( self, name: str, /, *, catalog: str | None = None, force: bool = False ) -> None: - name = sg.table(name, catalog=catalog, quoted=self.compiler.quoted) - with self._safe_raw_sql(sge.Create(this=name, kind="SCHEMA", replace=force)): + full_name = sg.table(name, catalog=catalog, quoted=self.compiler.quoted) + with self._safe_raw_sql( + sge.Create(this=full_name, kind="SCHEMA", replace=force) + ): pass def drop_database( self, name: str, /, *, catalog: str | None = None, force: bool = False ) -> None: - name = sg.table(name, catalog=catalog, quoted=self.compiler.quoted) - with self._safe_raw_sql(sge.Drop(this=name, kind="SCHEMA", replace=force)): + full_name = sg.table(name, catalog=catalog, quoted=self.compiler.quoted) + with self._safe_raw_sql(sge.Drop(this=full_name, kind="SCHEMA", replace=force)): pass def list_tables( @@ -531,7 +535,7 @@ def to_pyarrow_batches( /, *, params: Mapping[ir.Scalar, Any] | None = None, - limit: int | str | None = None, + limit: int | None = None, chunk_size: int = 1_000_000, **_: Any, ) -> pa.ipc.RecordBatchReader: @@ -573,7 +577,7 @@ def to_pyarrow( /, *, params: Mapping[ir.Scalar, Any] | None = None, - limit: int | str | None = None, + limit: int | None = None, **kwargs: Any, ) -> pa.Table: self._run_pre_execute_hooks(expr) @@ -592,7 +596,7 @@ def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame: if (table := cursor.fetchall_arrow()) is None: table = schema.to_pyarrow().empty_table() df = table.to_pandas(timestamp_as_object=True) - df.columns = list(schema.names) + df.columns = list(schema) return df def _get_schema_using_query(self, query: str) -> sch.Schema: