Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/database/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,38 @@ def tag(id_: int, tag_: str, *, user_id: int, connection: Connection) -> None:
)


def get_tag_for(id_: int, tag_: str, connection: Connection) -> Row | None:
    """Look up the ``dataset_tag`` row that links dataset ``id_`` to ``tag_``.

    Args:
        id_: Value matched against the ``id`` column of ``dataset_tag``.
        tag_: Value matched against the ``tag`` column of ``dataset_tag``.
        connection: Connection the query is executed on.

    Returns:
        The single matching row, or ``None`` when no row matches.
    """
    lookup = text(
        "\n"
        "SELECT *\n"
        "FROM dataset_tag\n"
        "WHERE `id` = :dataset_id AND `tag` = :tag\n",
    )
    bindings = {"dataset_id": id_, "tag": tag_}
    result = connection.execute(lookup, parameters=bindings)
    # SQLAlchemy's one_or_none() raises if more than one row matches.
    return result.one_or_none()


def untag(id_: int, tag_: str, *, connection: Connection) -> None:
    """Delete the ``dataset_tag`` rows that attach ``tag_`` to dataset ``id_``.

    Args:
        id_: Value matched against the ``id`` column of ``dataset_tag``.
        tag_: Value matched against the ``tag`` column of ``dataset_tag``.
        connection: Connection the statement is executed on (keyword-only).
    """
    removal = text(
        "\n"
        "DELETE FROM dataset_tag\n"
        "WHERE `id` = :dataset_id AND `tag` = :tag\n",
    )
    bindings = {"dataset_id": id_, "tag": tag_}
    connection.execute(removal, parameters=bindings)


def get_description(
id_: int,
connection: Connection,
Expand Down
39 changes: 39 additions & 0 deletions src/routers/openml/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,52 @@ def tag_dataset(
}


@router.post(path="/untag")
def untag_dataset(
    data_id: Annotated[int, Body()],
    tag: Annotated[str, SystemString64],
    user: Annotated[User | None, Depends(fetch_user)] = None,
    expdb_db: Annotated[Connection, Depends(expdb_connection)] = None,
) -> dict[str, dict[str, Any]]:
    """Remove ``tag`` from the dataset identified by ``data_id``.

    The request is rejected when no user was resolved, when the dataset does
    not carry the tag, or when the caller is neither the tag's uploader nor a
    member of the admin group.

    Returns:
        ``{"data_untag": {"id": "<data_id>"}}`` with the id rendered as a string.

    Raises:
        HTTPException: code 103 (no user), 475 (tag not found) or
            476 (tag not owned by the caller).
    """
    if user is None:
        raise create_authentication_failed_error()

    existing = database.datasets.get_tag_for(data_id, tag, expdb_db)
    if existing is None:
        raise create_tag_not_found_error()

    # The group lookup stays behind the short-circuit: it is only evaluated
    # when the caller is not the uploader of the tag.
    not_uploader = existing.uploader != user.user_id
    if not_uploader and UserGroup.ADMIN not in user.groups:
        raise create_tag_not_owned_error()

    database.datasets.untag(data_id, tag, connection=expdb_db)
    untagged = {"id": str(data_id)}
    return {"data_untag": untagged}


def create_authentication_failed_error() -> HTTPException:
    """Build (not raise) the 412 error carrying code 103, "Authentication failed"."""
    detail = {"code": "103", "message": "Authentication failed"}
    return HTTPException(status_code=HTTPStatus.PRECONDITION_FAILED, detail=detail)


def create_tag_not_found_error() -> HTTPException:
    """Build (not raise) the 500 error carrying code 475, "Tag not found."."""
    detail = {"code": "475", "message": "Tag not found."}
    return HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=detail)


def create_tag_not_owned_error() -> HTTPException:
    """Build (not raise) the 500 error carrying code 476, "Tag is not owned by you"."""
    detail = {"code": "476", "message": "Tag is not owned by you"}
    return HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=detail)


def create_tag_exists_error(data_id: int, tag: str) -> HTTPException:
return HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
Expand Down
83 changes: 83 additions & 0 deletions tests/routers/openml/dataset_tag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,86 @@ def test_dataset_tag_invalid_tag_is_rejected(

assert new.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
assert new.json()["detail"][0]["loc"] == ["body", "tag"]


@pytest.mark.parametrize(
    "key",
    [None, ApiKey.INVALID],
    ids=["no authentication", "invalid key"],
)
def test_dataset_untag_rejects_unauthorized(key: ApiKey | None, py_api: TestClient) -> None:
    """Untagging without an API key, or with an invalid one, yields error 103.

    ``key`` is ``None`` for the "no authentication" case, in which no
    ``api_key`` query parameter is sent at all.
    """
    # Omit the query string entirely when there is no key to send.
    apikey = "" if key is None else f"?api_key={key}"
    response = py_api.post(
        f"/datasets/untag{apikey}",
        json={"data_id": 1, "tag": "study_14"},
    )
    assert response.status_code == HTTPStatus.PRECONDITION_FAILED
    assert response.json()["detail"] == {"code": "103", "message": "Authentication failed"}


def test_dataset_untag(py_api: TestClient, expdb_test: Connection) -> None:
    """A user can remove a tag they just added, and it is gone afterwards."""
    dataset_id, tag = 1, "temp_dataset_untag"
    payload = {"data_id": dataset_id, "tag": tag}

    # Arrange: attach the tag as SOME_USER so that there is something to remove.
    py_api.post(f"/datasets/tag?api_key={ApiKey.SOME_USER}", json=payload)

    response = py_api.post(f"/datasets/untag?api_key={ApiKey.SOME_USER}", json=payload)

    assert response.status_code == HTTPStatus.OK
    expected_body = {"data_untag": {"id": str(dataset_id)}}
    assert response.json() == expected_body
    remaining_tags = get_tags_for(id_=dataset_id, connection=expdb_test)
    assert tag not in remaining_tags


def test_dataset_untag_rejects_other_user(py_api: TestClient) -> None:
    """A tag added by SOME_USER cannot be removed by OWNER_USER (error 476).

    The temporary tag is removed again by SOME_USER in a ``finally`` clause so
    that the cleanup request is sent even when one of the assertions fails.
    """
    dataset_id = 1
    tag = "temp_dataset_untag_not_owned"
    py_api.post(
        f"/datasets/tag?api_key={ApiKey.SOME_USER}",
        json={"data_id": dataset_id, "tag": tag},
    )

    try:
        response = py_api.post(
            f"/datasets/untag?api_key={ApiKey.OWNER_USER}",
            json={"data_id": dataset_id, "tag": tag},
        )
        assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
        assert response.json()["detail"] == {"code": "476", "message": "Tag is not owned by you"}
    finally:
        # Always try to remove the temporary tag as the user who created it.
        cleanup = py_api.post(
            f"/datasets/untag?api_key={ApiKey.SOME_USER}",
            json={"data_id": dataset_id, "tag": tag},
        )
        assert cleanup.status_code == HTTPStatus.OK


def test_dataset_untag_fails_if_tag_does_not_exist(py_api: TestClient) -> None:
    """Untagging with a tag the dataset does not carry yields error 475."""
    payload = {"data_id": 1, "tag": "definitely_not_a_dataset_tag"}

    response = py_api.post(f"/datasets/untag?api_key={ApiKey.ADMIN}", json=payload)

    assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
    expected_detail = {"code": "475", "message": "Tag not found."}
    assert response.json()["detail"] == expected_detail


@pytest.mark.parametrize(
    "tag",
    ["", "h@", " a", "a" * 65],
    ids=["too short", "@", "space", "too long"],
)
def test_dataset_untag_invalid_tag_is_rejected(
    tag: str,
    py_api: TestClient,
) -> None:
    """Malformed tags are rejected with 422 and the error points at ``body.tag``."""
    payload = {"data_id": 1, "tag": tag}

    response = py_api.post(f"/datasets/untag?api_key={ApiKey.ADMIN}", json=payload)

    assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
    first_error = response.json()["detail"][0]
    assert first_error["loc"] == ["body", "tag"]
43 changes: 43 additions & 0 deletions tests/routers/openml/migration/datasets_migration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,49 @@ def test_dataset_tag_response_is_identical(
assert original == new


@pytest.mark.parametrize(
    "dataset_id",
    [1, 2, 3, 101, 131],
)
@pytest.mark.parametrize(
    "api_key",
    [ApiKey.ADMIN, ApiKey.SOME_USER, ApiKey.OWNER_USER],
    ids=["Administrator", "regular user", "possible owner"],
)
@pytest.mark.parametrize(
    "tag",
    ["study_14", "study_15"],
)
def test_dataset_untag_response_is_identical(
    dataset_id: int,
    tag: str,
    api_key: str,
    py_api: TestClient,
    php_api: httpx.Client,
) -> None:
    """The Python ``/datasets/untag`` response matches the PHP ``/data/untag`` one.

    Status codes must agree. On success the JSON bodies must be equal; on
    failure the PHP ``error`` object must equal the Python ``detail`` object.
    """
    php_form = {"api_key": api_key, "tag": tag, "data_id": dataset_id}

    original = php_api.post("/data/untag", data=php_form)
    if original.status_code == HTTPStatus.OK:
        # The PHP call removed the tag; add it back with the same key,
        # presumably so the Python call below also finds a tag to remove.
        php_api.post("/data/tag", data=php_form)

    new = py_api.post(
        f"/datasets/untag?api_key={api_key}",
        json={"data_id": dataset_id, "tag": tag},
    )

    assert original.status_code == new.status_code, original.json()
    if new.status_code == HTTPStatus.OK:
        assert original.json() == new.json()
    else:
        # Error payloads live under different keys in the two APIs.
        assert original.json()["error"] == new.json()["detail"]


@pytest.mark.parametrize(
"data_id",
list(range(1, 130)),
Expand Down