diff --git a/ibis/backends/trino/__init__.py b/ibis/backends/trino/__init__.py index cd0938efe853..e2654f488e01 100644 --- a/ibis/backends/trino/__init__.py +++ b/ibis/backends/trino/__init__.py @@ -6,12 +6,12 @@ from functools import cached_property from operator import itemgetter from typing import TYPE_CHECKING, Any -from urllib.parse import unquote_plus, urlparse +from urllib.parse import unquote_plus import sqlglot as sg import sqlglot.expressions as sge import trino -from trino.auth import BasicAuthentication +from trino.auth import Authentication, BasicAuthentication import ibis import ibis.backends.sql.compilers as sc @@ -60,7 +60,10 @@ def _from_url(self, url: ParseResult, **kwarg_overrides): if url.password: kwargs["auth"] = unquote_plus(url.password) if url.hostname: - kwargs["host"] = url.hostname + # Do NOT convert to url.hostname, the trino client expects an entire URL + # to do inference on http vs https, port, etc. + # https://github.com/trinodb/trino-python-client/blob/2108c38dea79518ffb74370177df2dc95f1e6d96/trino/dbapi.py#L169 + kwargs["host"] = url if database: kwargs["database"] = database if url.port: @@ -245,11 +248,12 @@ def list_tables( def do_connect( self, user: str = "user", - auth: str | None = None, + auth: str | Authentication | None = None, host: str = "localhost", port: int = 8080, database: str | None = None, schema: str | None = None, + *, source: str | None = None, timezone: str = "UTC", **kwargs, @@ -261,7 +265,7 @@ def do_connect( user Username to connect with auth - Authentication method or password to use for the connection. + Password or authentication method to use for the connection. host Hostname of the Trino server port @@ -296,11 +300,7 @@ def do_connect( >>> con = ibis.trino.connect(database=catalog, schema=schema) >>> con = ibis.trino.connect(database=catalog, schema=schema, source="my-app") """ - if ( - isinstance(auth, str) - and (scheme := urlparse(host).scheme) - and scheme != "http" - ): + if isinstance(auth, str): auth = BasicAuthentication(user, auth) self.con = trino.dbapi.connect( diff --git a/ibis/backends/trino/tests/test_client.py b/ibis/backends/trino/tests/test_client.py index 08a03ec6a9f7..7d1fbae64806 100644 --- a/ibis/backends/trino/tests/test_client.py +++ b/ibis/backends/trino/tests/test_client.py @@ -68,6 +68,37 @@ def test_list_databases(con): assert {"information_schema", "sf1"}.issubset(con.list_databases(catalog="tpch")) +@pytest.mark.parametrize("auth_type", ["str", "obj"]) +@pytest.mark.parametrize(("http_scheme"), ["http", "https", None]) +def test_con_auth_str(http_scheme, auth_type): + """If a non-str object is given to auth, it should be used as is, no matter what. + If a str is given to auth, it should be converted to a BasicAuthentication. + This happens no matter the http_scheme given. + + Tests for + https://github.com/ibis-project/ibis/issues/9113 + https://github.com/ibis-project/ibis/issues/9956 + https://github.com/ibis-project/ibis/issues/11841 + """ + from trino.auth import BasicAuthentication + + if auth_type == "str": + auth = TRINO_PASS + else: + auth = BasicAuthentication(TRINO_USER, TRINO_PASS) + + con = ibis.trino.connect( + user=TRINO_USER, + host=TRINO_HOST, + port=TRINO_PORT, + auth=auth, + database="hive", + schema="default", + http_scheme=http_scheme, + ) + assert con.con.auth == BasicAuthentication(TRINO_USER, TRINO_PASS) + + @pytest.mark.parametrize(("source", "expected"), [(None, "ibis"), ("foo", "foo")]) def test_con_source(source, expected): con = ibis.trino.connect(