minecraft-src/net/minecraft/util/random/WeightedList.java
2025-07-04 03:45:38 +03:00

187 lines
5 KiB
Java

package net.minecraft.util.random;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.mojang.serialization.Codec;
import com.mojang.serialization.MapCodec;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import net.minecraft.util.ExtraCodecs;
import net.minecraft.util.RandomSource;
import org.jetbrains.annotations.Nullable;
/**
 * An immutable list of {@link Weighted} entries from which a random element can be drawn with
 * probability proportional to its weight.
 *
 * <p>On construction the total weight is computed and a {@link Selector} is chosen:
 * <ul>
 *   <li>no selector when the total weight is zero (nothing can be drawn);
 *   <li>a lookup-table {@link Flat} selector when the total weight is below {@link #FLAT_THRESHOLD};
 *   <li>a linear-scan {@link Compact} selector otherwise.
 * </ul>
 *
 * @param <E> the element type
 */
public final class WeightedList<E> {
    /** Total weights strictly below this value use the {@link Flat} lookup-table selector. */
    private static final int FLAT_THRESHOLD = 64;

    private final int totalWeight;
    private final List<Weighted<E>> items;
    /** {@code null} exactly when {@link #totalWeight} is zero. */
    @Nullable
    private final WeightedList.Selector<E> selector;

    /**
     * Creates a list holding an immutable snapshot of {@code items}.
     *
     * @param items the weighted entries; copied, so later changes to the argument are not observed
     */
    WeightedList(List<? extends Weighted<E>> items) {
        this.items = List.copyOf(items);
        // Sum over the snapshot rather than the argument: when the argument is a lazy view (as
        // produced by map), iterating it a second time would re-apply the mapping function.
        // NOTE(review): negative-weight / overflow validation is assumed to happen inside
        // WeightedRandom.getTotalWeight — confirm.
        this.totalWeight = WeightedRandom.getTotalWeight(this.items, Weighted::weight);
        if (this.totalWeight == 0) {
            this.selector = null;
        } else if (this.totalWeight < FLAT_THRESHOLD) {
            this.selector = new WeightedList.Flat<>(this.items, this.totalWeight);
        } else {
            this.selector = new WeightedList.Compact<>(this.items);
        }
    }

    /** Returns a list with no entries; drawing from it always yields nothing. */
    public static <E> WeightedList<E> of() {
        return new WeightedList<>(List.of());
    }

    /** Returns a list containing only {@code element}, with weight 1. */
    public static <E> WeightedList<E> of(E element) {
        return new WeightedList<>(List.of(new Weighted<>(element, 1)));
    }

    /** Returns a list containing the given entries, in order. */
    @SafeVarargs
    public static <E> WeightedList<E> of(Weighted<E>... items) {
        return new WeightedList<>(List.of(items));
    }

    /** Returns a list containing a snapshot of the given entries, in order. */
    public static <E> WeightedList<E> of(List<Weighted<E>> items) {
        return new WeightedList<>(items);
    }

    /** Returns a new, empty {@link Builder}. */
    public static <E> WeightedList.Builder<E> builder() {
        return new WeightedList.Builder<>();
    }

    /**
     * Returns whether this list has no entries.
     *
     * <p>This checks the entry count, not the total weight. If {@link Weighted} permits zero
     * weights, a list whose weights sum to zero is not empty here, yet {@link #getRandom} still
     * returns {@link Optional#empty()} for it.
     */
    public boolean isEmpty() {
        return this.items.isEmpty();
    }

    /**
     * Returns a new list with {@code mapper} applied to every value, keeping each entry's weight.
     *
     * <p>The transformed view is copied by the constructor, so {@code mapper} runs while the new
     * list is being built rather than lazily on later access.
     *
     * @param mapper function applied to each entry's value
     * @param <T> the element type of the new list
     * @return a new list with the mapped values
     */
    public <T> WeightedList<T> map(Function<E, T> mapper) {
        return new WeightedList<>(Lists.transform(this.items, weighted -> weighted.map(mapper)));
    }

    /**
     * Draws an element with probability proportional to its weight.
     *
     * @param random source of the index drawn from {@code [0, totalWeight)}
     * @return the chosen element, or {@link Optional#empty()} when the total weight is zero
     */
    public Optional<E> getRandom(RandomSource random) {
        if (this.selector == null) {
            return Optional.empty();
        }
        return Optional.of(this.selector.get(random.nextInt(this.totalWeight)));
    }

    /**
     * Draws an element with probability proportional to its weight.
     *
     * @param random source of the index drawn from {@code [0, totalWeight)}
     * @return the chosen element
     * @throws IllegalStateException when the total weight is zero
     */
    public E getRandomOrThrow(RandomSource random) {
        if (this.selector == null) {
            throw new IllegalStateException("Weighted list has no elements");
        }
        return this.selector.get(random.nextInt(this.totalWeight));
    }

    /** Returns the entries as the internal unmodifiable snapshot taken at construction. */
    public List<Weighted<E>> unwrap() {
        return this.items;
    }

    /** Codec that reads and writes this list as a plain list of weighted entries. */
    public static <E> Codec<WeightedList<E>> codec(Codec<E> elementCodec) {
        return Weighted.codec(elementCodec).listOf().xmap(WeightedList::of, WeightedList::unwrap);
    }

    /** Codec that reads and writes this list as a plain list of weighted entries. */
    public static <E> Codec<WeightedList<E>> codec(MapCodec<E> elementCodec) {
        return Weighted.codec(elementCodec).listOf().xmap(WeightedList::of, WeightedList::unwrap);
    }

    /**
     * Like {@link #codec(Codec)}, but wraps the list codec in {@code ExtraCodecs.nonEmptyList}
     * (presumably rejecting empty lists — confirm in ExtraCodecs).
     */
    public static <E> Codec<WeightedList<E>> nonEmptyCodec(Codec<E> elementCodec) {
        return ExtraCodecs.nonEmptyList(Weighted.codec(elementCodec).listOf())
                .xmap(WeightedList::of, WeightedList::unwrap);
    }

    /**
     * Like {@link #codec(MapCodec)}, but wraps the list codec in {@code ExtraCodecs.nonEmptyList}
     * (presumably rejecting empty lists — confirm in ExtraCodecs).
     */
    public static <E> Codec<WeightedList<E>> nonEmptyCodec(MapCodec<E> elementCodec) {
        return ExtraCodecs.nonEmptyList(Weighted.codec(elementCodec).listOf())
                .xmap(WeightedList::of, WeightedList::unwrap);
    }

    /** Returns whether any entry's value {@code equals} {@code element}, regardless of weight. */
    public boolean contains(E element) {
        for (Weighted<E> weighted : this.items) {
            if (weighted.value().equals(element)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public boolean equals(@Nullable Object object) {
        if (this == object) {
            return true;
        }
        return object instanceof WeightedList<?> other
                && this.totalWeight == other.totalWeight
                && Objects.equals(this.items, other.items);
    }

    @Override
    public int hashCode() {
        return 31 * this.totalWeight + this.items.hashCode();
    }

    /** Accumulates weighted entries in insertion order and builds a {@link WeightedList}. */
    public static class Builder<E> {
        private final ImmutableList.Builder<Weighted<E>> result = ImmutableList.builder();

        /** Adds {@code element} with weight 1. */
        public WeightedList.Builder<E> add(E element) {
            return this.add(element, 1);
        }

        /** Adds {@code element} with the given weight. */
        public WeightedList.Builder<E> add(E element, int weight) {
            this.result.add(new Weighted<>(element, weight));
            return this;
        }

        /** Builds a list from the entries added so far. */
        public WeightedList<E> build() {
            return new WeightedList<>(this.result.build());
        }
    }

    /**
     * Selector that walks the entries in order, subtracting each weight from the index until it
     * falls inside an entry; lookup cost grows with the number of entries.
     */
    static class Compact<E> implements WeightedList.Selector<E> {
        private final Weighted<E>[] entries;

        @SuppressWarnings("unchecked") // every element of the source list is a Weighted<E>
        Compact(List<Weighted<E>> entries) {
            this.entries = entries.toArray(Weighted[]::new);
        }

        @Override
        public E get(int index) {
            int remaining = index;
            for (Weighted<E> weighted : this.entries) {
                remaining -= weighted.weight();
                if (remaining < 0) {
                    return weighted.value();
                }
            }
            // Report the caller's index; `remaining` has already had every weight subtracted.
            throw new IllegalStateException(index + " exceeded total weight");
        }
    }

    /**
     * Selector that expands each entry into {@code weight} consecutive array slots, so a lookup is
     * a single array access. Memory use is proportional to the total weight.
     */
    static class Flat<E> implements WeightedList.Selector<E> {
        private final Object[] entries;

        /**
         * @param entries the weighted entries
         * @param size the total weight, i.e. the number of slots to allocate
         */
        Flat(List<Weighted<E>> entries, int size) {
            this.entries = new Object[size];
            int offset = 0;
            for (Weighted<E> weighted : entries) {
                int weight = weighted.weight();
                Arrays.fill(this.entries, offset, offset + weight, weighted.value());
                offset += weight;
            }
        }

        @Override
        @SuppressWarnings("unchecked") // slots are only ever filled with values of type E
        public E get(int index) {
            return (E) this.entries[index];
        }
    }

    /** Maps an index in {@code [0, totalWeight)} to the element whose weight span covers it. */
    interface Selector<E> {
        E get(int index);
    }
}