diff --git a/commitizen/changelog_formats/__init__.py b/commitizen/changelog_formats/__init__.py index b7b3cac01..9a5eea7ab 100644 --- a/commitizen/changelog_formats/__init__.py +++ b/commitizen/changelog_formats/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from typing import ClassVar, Protocol +from typing import Callable, ClassVar, Protocol if sys.version_info >= (3, 10): from importlib import metadata @@ -64,10 +64,9 @@ def get_changelog_format( :raises FormatUnknown: if a non-empty name is provided but cannot be found in the known formats """ name: str | None = config.settings.get("changelog_format") - format: type[ChangelogFormat] | None = guess_changelog_format(filename) - - if name and name in KNOWN_CHANGELOG_FORMATS: - format = KNOWN_CHANGELOG_FORMATS[name] + format = ( + name and KNOWN_CHANGELOG_FORMATS.get(name) or _guess_changelog_format(filename) + ) if not format: raise ChangelogFormatUnknown(f"Unknown changelog format '{name}'") @@ -75,7 +74,7 @@ def get_changelog_format( return format(config) -def guess_changelog_format(filename: str | None) -> type[ChangelogFormat] | None: +def _guess_changelog_format(filename: str | None) -> type[ChangelogFormat] | None: """ Try guessing the file format from the filename. @@ -91,3 +90,9 @@ def guess_changelog_format(filename: str | None) -> type[ChangelogFormat] | None if filename.endswith(f".{alt_extension}"): return format return None + + +def __getattr__(name: str) -> Callable[[str], type[ChangelogFormat] | None]: + if name == "guess_changelog_format": + return _guess_changelog_format + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/commitizen/factory.py b/commitizen/factory.py index b5d665b65..d9e99fb77 100644 --- a/commitizen/factory.py +++ b/commitizen/factory.py @@ -8,12 +8,10 @@ def committer_factory(config: BaseConfig) -> BaseCommitizen: """Return the correct commitizen existing in the registry.""" name: str = config.settings["name"] try: - _cz = registry[name](config) + return registry[name](config) except KeyError: msg_error = ( "The committer has not been found in the system.\n\n" f"Try running 'pip install {name}'\n" ) raise NoCommitizenFoundException(msg_error) - else: - return _cz diff --git a/tests/test_changelog_formats.py b/tests/test_changelog_formats.py index dec23720d..e0d99e032 100644 --- a/tests/test_changelog_formats.py +++ b/tests/test_changelog_formats.py @@ -6,8 +6,8 @@ from commitizen.changelog_formats import ( KNOWN_CHANGELOG_FORMATS, ChangelogFormat, + _guess_changelog_format, get_changelog_format, - guess_changelog_format, ) from commitizen.config.base_config import BaseConfig from commitizen.exceptions import ChangelogFormatUnknown @@ -15,14 +15,14 @@ @pytest.mark.parametrize("format", KNOWN_CHANGELOG_FORMATS.values()) def test_guess_format(format: type[ChangelogFormat]): - assert guess_changelog_format(f"CHANGELOG.{format.extension}") is format + assert _guess_changelog_format(f"CHANGELOG.{format.extension}") is format for ext in format.alternative_extensions: - assert guess_changelog_format(f"CHANGELOG.{ext}") is format + assert _guess_changelog_format(f"CHANGELOG.{ext}") is format @pytest.mark.parametrize("filename", ("CHANGELOG", "NEWS", "file.unknown", None)) def test_guess_format_unknown(filename: str): - assert guess_changelog_format(filename) is None + assert _guess_changelog_format(filename) is None @pytest.mark.parametrize( diff --git a/tests/test_defaults.py b/tests/test_deprecated.py similarity index 84% rename from tests/test_defaults.py rename to tests/test_deprecated.py index 73cd35b80..41bea81a7 100644 --- a/tests/test_defaults.py +++ b/tests/test_deprecated.py @@ -1,6 +1,6 @@ import pytest -from commitizen import defaults +from commitizen import changelog_formats, defaults def test_getattr_deprecated_vars(): @@ -15,6 +15,10 @@ def test_getattr_deprecated_vars(): assert defaults.change_type_order == defaults.CHANGE_TYPE_ORDER assert defaults.encoding == defaults.ENCODING assert defaults.name == defaults.DEFAULT_SETTINGS["name"] + assert ( + changelog_formats._guess_changelog_format + == changelog_formats.guess_changelog_format + ) # Verify warning messages assert len(record) == 7