diff --git a/src/main/java/com/mojang/serialization/codecs/CollectCodec.java b/src/main/java/com/mojang/serialization/codecs/CollectCodec.java new file mode 100644 index 00000000..a50e8b96 --- /dev/null +++ b/src/main/java/com/mojang/serialization/codecs/CollectCodec.java @@ -0,0 +1,186 @@ +package com.mojang.serialization.codecs; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collector; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import com.mojang.datafixers.util.Pair; +import com.mojang.datafixers.util.Unit; +import com.mojang.serialization.Codec; +import com.mojang.serialization.DataResult; +import com.mojang.serialization.DynamicOps; +import com.mojang.serialization.Lifecycle; +import com.mojang.serialization.ListBuilder; +import org.apache.commons.lang3.mutable.MutableObject; + +/** + * A codec for any types that have a {@link Collector} + * + * @see Collector + * @param the type of input elements to the reduction operation + * @param the mutable accumulation type of the reduction operation (often hidden as an implementation detail) + * @param the result type of the reduction operation + */ +public final class CollectCodec implements Codec { + public static Codec of(Collector collector, Function> iteratorFunction, Codec element) { + return new CollectCodec<>(collector, iteratorFunction, element); + } + + public static > Codec of(Collector collector, Codec element) { + return new CollectCodec<>(collector, Iterable::iterator, element); + } + + /** + * @see Collectors#toCollection(Supplier) + */ + public static , T> Codec collection(Supplier supplier, Codec elementCodec) { + return of(Collectors.toCollection(supplier), elementCodec); + } + + /** + * A codec for an immutable list + */ + public static Codec> list(Codec element) { + return of(Collectors.toUnmodifiableList(), element); + } + + /** + * A codec for an ArrayList + */ + public static Codec> arrayList(Codec element) { + return collection(ArrayList::new, element); + } + + /** + * A codec for an immutable set + */ + public static Codec> set(Codec element) { + return of(Collectors.toUnmodifiableSet(), element); + } + + public static Codec> hashSet(Codec element) { + return collection(HashSet::new, element); + } + + public static Codec> concurrentSet(Codec element) { + return collection(ConcurrentHashMap::newKeySet, element); + } + + /** + * A codec for an immutable map + */ + public static Codec> map(Codec key, Codec value) { + return of(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue), m -> m.entrySet().iterator(), MapEntryCodec.of(key, value)); + } + + public static Codec> map(Supplier> supplier, Codec key, Codec value) { + return of(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (a, b) -> { + throw new IllegalStateException("Key Conflict " + a + " & " + b); + }, supplier), m -> m.entrySet().iterator(), MapEntryCodec.of(key, value)); + } + + public static Codec> hashMap(Codec key, Codec value) { + return map(HashMap::new, key, value); + } + + public static Codec> concurrentMap(Supplier> supplier, Codec key, Codec value) { + return of(Collectors.toConcurrentMap(Map.Entry::getKey, Map.Entry::getValue, (a, b) -> { + throw new IllegalStateException("Key Conflict " + a + " & " + b); + }, supplier), m -> m.entrySet().iterator(), MapEntryCodec.of(key, value)); + } + + public static Codec> concurrentHashMap(Codec key, Codec value) { + return concurrentMap(ConcurrentHashMap::new, key, value); + } + + final Collector collector; + final Function> iterate; + final Codec elementCodec; + + public CollectCodec(Collector collector, Function> iterate, Codec codec) { + this.collector = collector; + this.iterate = iterate; + this.elementCodec = codec; + } + + @Override + public DataResult> decode(DynamicOps ops, X input) { + return ops.getList(input).setLifecycle(Lifecycle.stable()).flatMap(stream -> { + BiConsumer accumulator = this.collector.accumulator(); + A read = this.collector.supplier().get(); + + final Stream.Builder failed = Stream.builder(); + final MutableObject> result = new MutableObject<>(DataResult.success(Unit.INSTANCE, Lifecycle.stable())); + + stream.accept(t -> { + final DataResult> element = this.elementCodec.decode(ops, t); + element.error().ifPresent(e -> failed.add(t)); + result.setValue(result.getValue().apply2stable((r, v) -> { + accumulator.accept(read, v.getFirst()); + return r; + }, element)); + }); + + final R elements = this.collector.finisher().apply(read); + final X errors = ops.createList(failed.build()); + + final Pair pair = Pair.of(elements, errors); + return result.getValue().map(unit -> pair).setPartial(pair); + }); + } + + @Override + public DataResult encode(R input, DynamicOps ops, X prefix) { + final ListBuilder builder = ops.listBuilder(); + + Iterator apply = this.iterate.apply(input); + while(apply.hasNext()) { + T next = apply.next(); + DataResult result = this.elementCodec.encodeStart(ops, next); + builder.add(result); + } + return builder.build(prefix); + } + + @Override + public boolean equals(Object o) { + if(this == o) { + return true; + } + if(!(o instanceof CollectCodec)) { + return false; + } + + CollectCodec codec = (CollectCodec) o; + + if(!this.collector.equals(codec.collector)) { + return false; + } + if(!this.iterate.equals(codec.iterate)) { + return false; + } + return this.elementCodec.equals(codec.elementCodec); + } + + @Override + public int hashCode() { + int result = this.collector.hashCode(); + result = 31 * result + this.iterate.hashCode(); + result = 31 * result + this.elementCodec.hashCode(); + return result; + } +} diff --git a/src/main/java/com/mojang/serialization/codecs/MapEntryCodec.java b/src/main/java/com/mojang/serialization/codecs/MapEntryCodec.java new file mode 100644 index 00000000..331c0319 --- /dev/null +++ b/src/main/java/com/mojang/serialization/codecs/MapEntryCodec.java @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +package com.mojang.serialization.codecs; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import com.google.gson.JsonArray; +import com.mojang.datafixers.util.Pair; +import com.mojang.serialization.Codec; +import com.mojang.serialization.DataResult; +import com.mojang.serialization.DynamicOps; +import com.mojang.serialization.JsonOps; +import com.mojang.serialization.Lifecycle; +import com.mojang.serialization.ListBuilder; +import org.apache.commons.lang3.mutable.MutableInt; + +public final class MapEntryCodec implements Codec> { + private final Codec first; + private final Codec second; + + public static Codec> of(final Codec first, final Codec second) { + return new MapEntryCodec<>(first, second); + } + + public MapEntryCodec(final Codec first, final Codec second) { + this.first = first; + this.second = second; + } + + @Override + public DataResult, T>> decode(final DynamicOps ops, final T input) { + return ops.getList(input).setLifecycle(Lifecycle.stable()).flatMap(consumer -> { + List inputs = new ArrayList<>(3); + consumer.accept(inputs::add); + if(inputs.size() == 2) { + inputs.add(ops.empty()); + } else if(inputs.size() < 2) { + return DataResult.error("Expected atleast 2 elements for map entry, found " + inputs.size()); + } + return this.first + .decode(ops, inputs.get(0)) + .flatMap(p -> this.second.decode(ops, inputs.get(1)).map(p2 -> Pair.of(Map.entry(p.getFirst(), p2.getFirst()), inputs.get(2)))); + }); + } + + @Override + public DataResult encode(final Map.Entry value, final DynamicOps ops, final T rest) { + ListBuilder builder = ops.listBuilder(); + builder.add(this.first.encodeStart(ops, value.getKey())); + builder.add(this.second.encodeStart(ops, value.getValue())); + return builder.build(rest); + } + + @Override + public boolean equals(final Object o) { + if(this == o) { + return true; + } + if(o == null || this.getClass() != o.getClass()) { + return false; + } + final MapEntryCodec pairCodec = (MapEntryCodec) o; + return Objects.equals(this.first, pairCodec.first) && Objects.equals(this.second, pairCodec.second); + } + + @Override + public int hashCode() { + return Objects.hash(this.first, this.second); + } + + @Override + public String toString() { + return "Map.EntryCodec[" + this.first + ", " + this.second + ']'; + } +}