diff --git a/commitizen/changelog.py b/commitizen/changelog.py index ba6fbbc6b..4920b75a5 100644 --- a/commitizen/changelog.py +++ b/commitizen/changelog.py @@ -32,6 +32,7 @@ from collections.abc import Generator, Iterable, Mapping, Sequence from dataclasses import dataclass from datetime import date +from itertools import chain, tee from typing import TYPE_CHECKING, Any from jinja2 import ( @@ -281,6 +282,16 @@ def incremental_build( return output_lines +def get_next_tag_name_after_version(tags: Iterable[GitTag], version: str) -> str | None: + a, b = tee(chain((tag.name for tag in tags), [None])) + next(b, None) + try: + return next(y for x, y in zip(a, b) if x == version) + except StopIteration: + raise NoCommitsFoundError(f"Could not find a valid revision range. {version=}") + + +# TODO: unused, deprecate this? def get_smart_tag_range( tags: Sequence[GitTag], newest: str, oldest: str | None = None ) -> list[GitTag]: @@ -308,7 +319,7 @@ def get_smart_tag_range( def get_oldest_and_newest_rev( - tags: Sequence[GitTag], + tags: Iterable[GitTag], version: str, rules: TagRules, ) -> tuple[str | None, str]: @@ -318,39 +329,26 @@ def get_oldest_and_newest_rev( - `0.1.0..0.4.0`: as a range - `0.3.0`: as a single version """ - oldest: str | None = None - newest: str | None = None - try: - oldest, newest = version.split("..") - except ValueError: - newest = version - if not (newest_tag := rules.find_tag_for(tags, newest)): + oldest_version, sep, newest_version = version.partition("..") + if not sep: + newest_version = version + oldest_version = "" + + def get_tag_name(v: str) -> str: + if tag := rules.find_tag_for(tags, v): + return tag.name raise NoCommitsFoundError("Could not find a valid revision range.") - oldest_tag = None - oldest_tag_name = None - if oldest: - if not (oldest_tag := rules.find_tag_for(tags, oldest)): - raise NoCommitsFoundError("Could not find a valid revision range.") - oldest_tag_name = oldest_tag.name + newest_tag_name = get_tag_name(newest_version) + oldest_tag_name = get_tag_name(oldest_version) if oldest_version else None - tags_range = get_smart_tag_range( - tags, newest=newest_tag.name, oldest=oldest_tag_name + oldest_rev = get_next_tag_name_after_version( + tags, oldest_tag_name or newest_tag_name ) - if not tags_range: - raise NoCommitsFoundError("Could not find a valid revision range.") - - oldest_rev: str | None = tags_range[-1].name - newest_rev = newest_tag.name - - # check if it's the first tag created - # and it's also being requested as part of the range - if oldest_rev == tags[-1].name and oldest_rev == oldest_tag_name: - return None, newest_rev - - # when they are the same, and it's also the - # first tag created - if oldest_rev == newest_rev: - return None, newest_rev - return oldest_rev, newest_rev + # Return None for oldest_rev if: + # 1. The oldest tag is the last tag in the list and matches the requested oldest tag + # 2. The oldest and the newest tag are the same + if oldest_rev == newest_tag_name: + return None, newest_tag_name + return oldest_rev, newest_tag_name diff --git a/tests/test_changelog.py b/tests/test_changelog.py index ed90ed08e..4465fcccb 100644 --- a/tests/test_changelog.py +++ b/tests/test_changelog.py @@ -1535,6 +1535,24 @@ def test_get_smart_tag_range_returns_an_extra_for_a_single_tag(tags): assert 2 == len(res) +def test_get_next_tag_name_after_version(tags): + # Test finding next tag after a version + next_tag_name = changelog.get_next_tag_name_after_version(tags, "v1.2.0") + assert next_tag_name == "v1.1.1" + + next_tag_name = changelog.get_next_tag_name_after_version(tags, "v1.1.0") + assert next_tag_name == "v1.0.0" + + # Test finding last tag when given version is last + last_tag_name = changelog.get_next_tag_name_after_version(tags, "v0.9.1") + assert last_tag_name is None + + # Test error when version not found + with pytest.raises(changelog.NoCommitsFoundError) as exc_info: + changelog.get_next_tag_name_after_version(tags, "nonexistent") + assert "Could not find a valid revision range" in str(exc_info.value) + + @dataclass class TagDef: name: str