diff --git a/Base.Tests/Extensions/EnumerableExtensionTests.cs b/Base.Tests/Extensions/EnumerableExtensionTests.cs index 73b7348..1b0f0e7 100644 --- a/Base.Tests/Extensions/EnumerableExtensionTests.cs +++ b/Base.Tests/Extensions/EnumerableExtensionTests.cs @@ -15,4 +15,31 @@ public class EnumerableExtensionTests [TestCase(new object[] { "hi", "there" }, false, "there", ExpectedResult = new object[] { "hi", "there" })] public object[] TestConditionalWhere(object[] input, bool isConditionValid, object valueToKeep) => input.ConditionalWhere(isConditionValid, x => x.Equals(valueToKeep)).ToArray(); + + [Test] + public void TestChooseWithRefType() + { + var input = new [] { "one", null, "two" }; + + var result = input.Choose(x => x).ToArray(); + + Assert.That(result.GetType(), Is.EqualTo(typeof(string[]))); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0], Is.EqualTo("one")); + Assert.That(result[1], Is.EqualTo("two")); + } + + [Test] + public void TestChooseWithValueType() + { + var input = new int?[] { 1, null, 2 }; + + var result = input.Choose(x => x).ToArray(); + + Assert.That(result.GetType(), Is.EqualTo(typeof(int[]))); + Assert.That(result.GetType(), Is.Not.EqualTo(typeof(int?[]))); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0], Is.EqualTo(1)); + Assert.That(result[1], Is.EqualTo(2)); + } } \ No newline at end of file diff --git a/Base/Extensions/EnumerableExtensions.cs b/Base/Extensions/EnumerableExtensions.cs index ef0eede..8481e77 100644 --- a/Base/Extensions/EnumerableExtensions.cs +++ b/Base/Extensions/EnumerableExtensions.cs @@ -12,9 +12,17 @@ public static class EnumerableExtensions public static IEnumerable ConditionalWhere(this IEnumerable enumerable, bool isConditionValid, Func pred) => !isConditionValid ? enumerable : enumerable.Where(pred); - public static IEnumerable Choose(this IEnumerable enumerable, Func mapper) => + public static IEnumerable Choose(this IEnumerable enumerable, Func mapper) + where TInput : class? => enumerable .Select(mapper) .Where(x => x is not null) .Cast(); + + public static IEnumerable Choose(this IEnumerable enumerable, Func mapper) + where TOutput : struct => + enumerable + .Select(mapper) + .Where(x => x.HasValue) + .Select(x => x!.Value); } \ No newline at end of file