package net.minecraft.util.random; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.mojang.serialization.Codec; import com.mojang.serialization.MapCodec; import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.function.Function; import net.minecraft.util.ExtraCodecs; import net.minecraft.util.RandomSource; import org.jetbrains.annotations.Nullable; public final class WeightedList { private static final int FLAT_THRESHOLD = 64; private final int totalWeight; private final List> items; @Nullable private final WeightedList.Selector selector; WeightedList(List> items) { this.items = List.copyOf(items); this.totalWeight = WeightedRandom.getTotalWeight(items, Weighted::weight); if (this.totalWeight == 0) { this.selector = null; } else if (this.totalWeight < 64) { this.selector = new WeightedList.Flat<>(this.items, this.totalWeight); } else { this.selector = new WeightedList.Compact<>(this.items); } } public static WeightedList of() { return new WeightedList<>(List.of()); } public static WeightedList of(E element) { return new WeightedList<>(List.of(new Weighted<>(element, 1))); } @SafeVarargs public static WeightedList of(Weighted... items) { return new WeightedList<>(List.of(items)); } public static WeightedList of(List> items) { return new WeightedList<>(items); } public static WeightedList.Builder builder() { return new WeightedList.Builder<>(); } public boolean isEmpty() { return this.items.isEmpty(); } public WeightedList map(Function mapper) { return new WeightedList(Lists.transform(this.items, weighted -> weighted.map((Function)mapper))); } public Optional getRandom(RandomSource random) { if (this.selector == null) { return Optional.empty(); } else { int i = random.nextInt(this.totalWeight); return Optional.of(this.selector.get(i)); } } public E getRandomOrThrow(RandomSource random) { if (this.selector == null) { throw new IllegalStateException("Weighted list has no elements"); } else { int i = random.nextInt(this.totalWeight); return this.selector.get(i); } } public List> unwrap() { return this.items; } public static Codec> codec(Codec elementCodec) { return Weighted.codec(elementCodec).listOf().xmap(WeightedList::of, WeightedList::unwrap); } public static Codec> codec(MapCodec elementCodec) { return Weighted.codec(elementCodec).listOf().xmap(WeightedList::of, WeightedList::unwrap); } public static Codec> nonEmptyCodec(Codec elementCodec) { return ExtraCodecs.nonEmptyList(Weighted.codec(elementCodec).listOf()).xmap(WeightedList::of, WeightedList::unwrap); } public static Codec> nonEmptyCodec(MapCodec elementCodec) { return ExtraCodecs.nonEmptyList(Weighted.codec(elementCodec).listOf()).xmap(WeightedList::of, WeightedList::unwrap); } public boolean contains(E element) { for (Weighted weighted : this.items) { if (weighted.value().equals(element)) { return true; } } return false; } public boolean equals(@Nullable Object object) { if (this == object) { return true; } else { return !(object instanceof WeightedList weightedList) ? false : this.totalWeight == weightedList.totalWeight && Objects.equals(this.items, weightedList.items); } } public int hashCode() { int i = this.totalWeight; return 31 * i + this.items.hashCode(); } public static class Builder { private final ImmutableList.Builder> result = ImmutableList.builder(); public WeightedList.Builder add(E element) { return this.add(element, 1); } public WeightedList.Builder add(E element, int weight) { this.result.add(new Weighted<>(element, weight)); return this; } public WeightedList build() { return new WeightedList<>(this.result.build()); } } static class Compact implements WeightedList.Selector { private final Weighted[] entries; Compact(List> entries) { this.entries = (Weighted[])entries.toArray(Weighted[]::new); } @Override public E get(int index) { for (Weighted weighted : this.entries) { index -= weighted.weight(); if (index < 0) { return (E)weighted.value(); } } throw new IllegalStateException(index + " exceeded total weight"); } } static class Flat implements WeightedList.Selector { private final Object[] entries; Flat(List> entries, int size) { this.entries = new Object[size]; int i = 0; for (Weighted weighted : entries) { int j = weighted.weight(); Arrays.fill(this.entries, i, i + j, weighted.value()); i += j; } } @Override public E get(int index) { return (E)this.entries[index]; } } interface Selector { E get(int index); } }