fix: handle IEnumerable.Choose<Nullable<TStruct>> correctly

This commit is contained in:
Kyle Ratti 2022-12-23 13:39:33 -05:00
parent 69bc12e01f
commit c17d3f10d4
No known key found for this signature in database
GPG Key ID: 4D429B6287C68DD9
2 changed files with 36 additions and 1 deletions

View File

@ -15,4 +15,31 @@ public class EnumerableExtensionTests
[TestCase(new object[] { "hi", "there" }, false, "there", ExpectedResult = new object[] { "hi", "there" })]
public object[] TestConditionalWhere(object[] input, bool isConditionValid, object valueToKeep) =>
input.ConditionalWhere(isConditionValid, x => x.Equals(valueToKeep)).ToArray();
[Test]
public void TestChooseWithRefType()
{
var input = new [] { "one", null, "two" };
var result = input.Choose(x => x).ToArray();
Assert.That(result.GetType(), Is.EqualTo(typeof(string[])));
Assert.That(result.Length, Is.EqualTo(2));
Assert.That(result[0], Is.EqualTo("one"));
Assert.That(result[1], Is.EqualTo("two"));
}
[Test]
public void TestChooseWithValueType()
{
var input = new int?[] { 1, null, 2 };
var result = input.Choose(x => x).ToArray();
Assert.That(result.GetType(), Is.EqualTo(typeof(int[])));
Assert.That(result.GetType(), Is.Not.EqualTo(typeof(int?[])));
Assert.That(result.Length, Is.EqualTo(2));
Assert.That(result[0], Is.EqualTo(1));
Assert.That(result[1], Is.EqualTo(2));
}
}

View File

@ -12,9 +12,17 @@ public static class EnumerableExtensions
public static IEnumerable<T> ConditionalWhere<T>(this IEnumerable<T> enumerable, bool isConditionValid, Func<T, bool> pred) =>
!isConditionValid ? enumerable : enumerable.Where(pred);
public static IEnumerable<TOutput> Choose<TInput, TOutput>(this IEnumerable<TInput> enumerable, Func<TInput, TOutput?> mapper) =>
public static IEnumerable<TOutput> Choose<TInput, TOutput>(this IEnumerable<TInput> enumerable, Func<TInput, TOutput?> mapper)
where TInput : class? =>
enumerable
.Select(mapper)
.Where(x => x is not null)
.Cast<TOutput>();
public static IEnumerable<TOutput> Choose<TInput, TOutput>(this IEnumerable<TInput> enumerable, Func<TInput, TOutput?> mapper)
where TOutput : struct =>
enumerable
.Select(mapper)
.Where(x => x.HasValue)
.Select(x => x!.Value);
}