"""
Extended tests for retry_helper.py — covers decorator, functional API, and error classification.
"""

import pytest
from unittest.mock import patch, MagicMock

from retry_helper import (
    retry_with_backoff,
    retry_on_failure,
    safe_execute,
    is_retryable_error,
    RetryableError,
    NonRetryableError,
)


class TestRetryWithBackoff:
    """Behaviour of the ``retry_with_backoff`` decorator."""

    def test_succeeds_first_try(self):
        @retry_with_backoff(max_retries=3, initial_delay=0.01)
        def good_func():
            return "ok"

        assert good_func() == "ok"

    def test_retries_then_succeeds(self):
        # Mutable cell so the nested function can count its own invocations.
        attempts = [0]

        @retry_with_backoff(max_retries=3, initial_delay=0.01)
        def flaky():
            attempts[0] += 1
            if attempts[0] < 3:
                raise ConnectionError("fail")
            return "recovered"

        assert flaky() == "recovered"
        assert attempts[0] == 3

    def test_exhausts_retries(self):
        attempts = [0]

        @retry_with_backoff(max_retries=2, initial_delay=0.01)
        def always_fail():
            attempts[0] += 1
            raise ValueError("permanent")

        with pytest.raises(ValueError, match="permanent"):
            always_fail()
        # One initial call plus ``max_retries`` retries. This is the same count
        # test_respects_max_delay relies on (two failures, then success, with
        # max_retries=2).
        assert attempts[0] == 3

    def test_specific_exception_filter(self):
        calls = [0]

        @retry_with_backoff(
            max_retries=2, initial_delay=0.01, exceptions=(ConnectionError,)
        )
        def wrong_exception():
            calls[0] += 1
            raise TypeError("not retryable")

        with pytest.raises(TypeError):
            wrong_exception()
        # An exception outside ``exceptions`` must propagate on the first call,
        # without being retried.
        assert calls[0] == 1

    def test_respects_max_delay(self):
        # NOTE(review): this only checks that the call recovers within the
        # retry budget. With initial_delay=0.01 and max_delay=0.02 a typical
        # backoff never exceeds the cap, so a missing cap would go unnoticed.
        # Consider patching the sleep used by retry_helper and asserting each
        # requested delay is <= max_delay.
        attempts = [0]

        @retry_with_backoff(max_retries=2, initial_delay=0.01, max_delay=0.02)
        def slow_fail():
            attempts[0] += 1
            if attempts[0] <= 2:
                raise ConnectionError("fail")
            return "ok"

        assert slow_fail() == "ok"

    def test_preserves_function_name(self):
        @retry_with_backoff(max_retries=1, initial_delay=0.01)
        def my_special_func():
            """My docstring."""
            return True

        # The decorator should copy the wrapped function's metadata.
        assert my_special_func.__name__ == "my_special_func"
        assert "My docstring" in my_special_func.__doc__


class TestRetryOnFailure:
    """Behaviour of the functional ``retry_on_failure`` API."""

    def test_function_succeeds(self):
        result = retry_on_failure(lambda: 42, max_retries=1, initial_delay=0.01)
        assert result == 42

    def test_function_retries_and_fails(self):
        def always_fail():
            raise RuntimeError("boom")

        with pytest.raises(RuntimeError):
            retry_on_failure(always_fail, max_retries=1, initial_delay=0.01)


class TestSafeExecute:
    """``safe_execute`` returns the result or a fallback, optionally logging."""

    def test_success_returns_value(self):
        result = safe_execute(lambda: "hello", fallback_value="default")
        assert result == "hello"

    def test_failure_returns_fallback(self):
        def fail():
            raise Exception("crash")

        result = safe_execute(fail, fallback_value="safe")
        assert result == "safe"

    def test_failure_returns_none_default(self):
        def fail():
            raise Exception("crash")

        result = safe_execute(fail)
        assert result is None

    def test_failure_logs_when_enabled(self):
        def fail():
            raise ValueError("logged")

        # Replace the module-level logger so the warning call can be observed.
        with patch("retry_helper.logger") as mock_logger:
            safe_execute(fail, log_errors=True)
            mock_logger.warning.assert_called_once()

    def test_failure_silent_when_disabled(self):
        def fail():
            raise ValueError("silent")

        with patch("retry_helper.logger") as mock_logger:
            safe_execute(fail, log_errors=False)
            mock_logger.warning.assert_not_called()


class TestIsRetryableError:
    """Classification of exceptions by type and by message content."""

    def test_retryable_error_class(self):
        assert is_retryable_error(RetryableError("retry me")) is True

    def test_non_retryable_error_class(self):
        assert is_retryable_error(NonRetryableError("no retry")) is False

    def test_timeout_error(self):
        assert is_retryable_error(Exception("Connection timeout")) is True

    def test_connection_error(self):
        assert is_retryable_error(Exception("connection refused")) is True

    def test_rate_limit_error(self):
        assert is_retryable_error(Exception("rate limit exceeded")) is True

    def test_429_error(self):
        assert is_retryable_error(Exception("HTTP 429 Too Many Requests")) is True

    def test_503_error(self):
        assert is_retryable_error(Exception("503 Service Unavailable")) is True

    def test_generic_error_not_retryable(self):
        assert is_retryable_error(ValueError("invalid input")) is False

    def test_temporary_error(self):
        assert is_retryable_error(Exception("temporary failure")) is True


class TestExceptionClasses:
    """Sanity checks on the custom exception types."""

    def test_retryable_error_is_exception(self):
        assert issubclass(RetryableError, Exception)

    def test_non_retryable_error_is_exception(self):
        assert issubclass(NonRetryableError, Exception)

    def test_retryable_error_message(self):
        e = RetryableError("test message")
        assert str(e) == "test message"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])