package net.minecraft.world.entity.ai; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.common.collect.ImmutableList.Builder; import com.mojang.datafixers.util.Pair; import com.mojang.logging.LogUtils; import com.mojang.serialization.Codec; import com.mojang.serialization.DataResult; import com.mojang.serialization.Dynamic; import com.mojang.serialization.DynamicOps; import com.mojang.serialization.MapCodec; import com.mojang.serialization.MapLike; import com.mojang.serialization.RecordBuilder; import it.unimi.dsi.fastutil.objects.ObjectArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.Map.Entry; import java.util.function.Supplier; import java.util.stream.Stream; import net.minecraft.core.registries.BuiltInRegistries; import net.minecraft.server.level.ServerLevel; import net.minecraft.util.VisibleForDebug; import net.minecraft.world.entity.LivingEntity; import net.minecraft.world.entity.ai.behavior.BehaviorControl; import net.minecraft.world.entity.ai.behavior.Behavior.Status; import net.minecraft.world.entity.ai.memory.ExpirableValue; import net.minecraft.world.entity.ai.memory.MemoryModuleType; import net.minecraft.world.entity.ai.memory.MemoryStatus; import net.minecraft.world.entity.ai.sensing.Sensor; import net.minecraft.world.entity.ai.sensing.SensorType; import net.minecraft.world.entity.schedule.Activity; import net.minecraft.world.entity.schedule.Schedule; import org.apache.commons.lang3.mutable.MutableObject; import org.jetbrains.annotations.Nullable; import org.slf4j.Logger; public class Brain { static final Logger LOGGER = LogUtils.getLogger(); private final Supplier>> codec; private static final int SCHEDULE_UPDATE_DELAY = 20; private final Map, Optional>> memories = Maps., Optional>>newHashMap(); private final Map>, Sensor> sensors = Maps.>, Sensor>newLinkedHashMap(); private final Map>>> availableBehaviorsByPriority = Maps.newTreeMap(); private Schedule schedule = Schedule.EMPTY; private final Map, MemoryStatus>>> activityRequirements = Maps., MemoryStatus>>>newHashMap(); private final Map>> activityMemoriesToEraseWhenStopped = Maps.>>newHashMap(); private Set coreActivities = Sets.newHashSet(); private final Set activeActivities = Sets.newHashSet(); private Activity defaultActivity = Activity.IDLE; private long lastScheduleUpdate = -9999L; public static Brain.Provider provider( Collection> memoryTypes, Collection>> sensorTypes ) { return new Brain.Provider<>(memoryTypes, sensorTypes); } public static Codec> codec( Collection> memoryTypes, Collection>> sensorTypes ) { final MutableObject>> mutableObject = new MutableObject<>(); mutableObject.setValue( (new MapCodec>() { @Override public Stream keys(DynamicOps dynamicOps) { return memoryTypes.stream() .flatMap(memoryModuleType -> memoryModuleType.getCodec().map(codec -> BuiltInRegistries.MEMORY_MODULE_TYPE.getKey(memoryModuleType)).stream()) .map(resourceLocation -> dynamicOps.createString(resourceLocation.toString())); } @Override public DataResult> decode(DynamicOps dynamicOps, MapLike mapLike) { MutableObject>>> mutableObjectx = new MutableObject<>(DataResult.success(ImmutableList.builder())); mapLike.entries() .forEach( pair -> { DataResult> dataResult = BuiltInRegistries.MEMORY_MODULE_TYPE.byNameCodec().parse(dynamicOps, (T)pair.getFirst()); DataResult> dataResult2 = dataResult.flatMap( memoryModuleType -> this.captureRead(memoryModuleType, dynamicOps, (T)pair.getSecond()) ); mutableObject.setValue(mutableObject.getValue().apply2(Builder::add, dataResult2)); } ); ImmutableList> immutableList = (ImmutableList>)mutableObjectx.getValue() .resultOrPartial(Brain.LOGGER::error) .map(Builder::build) .orElseGet(ImmutableList::of); return DataResult.success(new Brain<>(memoryTypes, sensorTypes, immutableList, mutableObject::getValue)); } private DataResult> captureRead(MemoryModuleType memoryModuleType, DynamicOps dynamicOps, T object) { return ((DataResult)memoryModuleType.getCodec() .map(DataResult::success) .orElseGet(() -> DataResult.error(() -> "No codec for memory: " + memoryModuleType))) .flatMap(codec -> codec.parse(dynamicOps, object)) .map(expirableValue -> new Brain.MemoryValue<>(memoryModuleType, Optional.of(expirableValue))); } public RecordBuilder encode(Brain input, DynamicOps ops, RecordBuilder prefix) { input.memories().forEach(memoryValue -> memoryValue.serialize(ops, prefix)); return prefix; } }) .fieldOf("memories") .codec() ); return mutableObject.getValue(); } public Brain( Collection> memoryModuleTypes, Collection>> sensorTypes, ImmutableList> memoryValues, Supplier>> codec ) { this.codec = codec; for (MemoryModuleType memoryModuleType : memoryModuleTypes) { this.memories.put(memoryModuleType, Optional.empty()); } for (SensorType> sensorType : sensorTypes) { this.sensors.put(sensorType, sensorType.create()); } for (Sensor sensor : this.sensors.values()) { for (MemoryModuleType memoryModuleType2 : sensor.requires()) { this.memories.put(memoryModuleType2, Optional.empty()); } } for (Brain.MemoryValue memoryValue : memoryValues) { memoryValue.setMemoryInternal(this); } } public DataResult serializeStart(DynamicOps ops) { return ((Codec)this.codec.get()).encodeStart(ops, this); } Stream> memories() { return this.memories .entrySet() .stream() .map(entry -> Brain.MemoryValue.createUnchecked((MemoryModuleType)entry.getKey(), (Optional>)entry.getValue())); } public boolean hasMemoryValue(MemoryModuleType type) { return this.checkMemory(type, MemoryStatus.VALUE_PRESENT); } public void clearMemories() { this.memories.keySet().forEach(memoryModuleType -> this.memories.put(memoryModuleType, Optional.empty())); } public void eraseMemory(MemoryModuleType type) { this.setMemory(type, Optional.empty()); } public void setMemory(MemoryModuleType memoryType, @Nullable U memory) { this.setMemory(memoryType, Optional.ofNullable(memory)); } public void setMemoryWithExpiry(MemoryModuleType memoryType, U memory, long timeToLive) { this.setMemoryInternal(memoryType, Optional.of(ExpirableValue.of(memory, timeToLive))); } public void setMemory(MemoryModuleType memoryType, Optional memory) { this.setMemoryInternal(memoryType, memory.map(ExpirableValue::of)); } void setMemoryInternal(MemoryModuleType memoryType, Optional> memory) { if (this.memories.containsKey(memoryType)) { if (memory.isPresent() && this.isEmptyCollection(((ExpirableValue)memory.get()).getValue())) { this.eraseMemory(memoryType); } else { this.memories.put(memoryType, memory); } } } public Optional getMemory(MemoryModuleType type) { Optional> optional = (Optional>)this.memories.get(type); if (optional == null) { throw new IllegalStateException("Unregistered memory fetched: " + type); } else { return optional.map(ExpirableValue::getValue); } } @Nullable public Optional getMemoryInternal(MemoryModuleType type) { Optional> optional = (Optional>)this.memories.get(type); return optional == null ? null : optional.map(ExpirableValue::getValue); } public long getTimeUntilExpiry(MemoryModuleType memoryType) { Optional> optional = (Optional>)this.memories.get(memoryType); return (Long)optional.map(ExpirableValue::getTimeToLive).orElse(0L); } @Deprecated @VisibleForDebug public Map, Optional>> getMemories() { return this.memories; } public boolean isMemoryValue(MemoryModuleType memoryType, U memory) { return !this.hasMemoryValue(memoryType) ? false : this.getMemory(memoryType).filter(object2 -> object2.equals(memory)).isPresent(); } public boolean checkMemory(MemoryModuleType memoryType, MemoryStatus memoryStatus) { Optional> optional = (Optional>)this.memories.get(memoryType); return optional == null ? false : memoryStatus == MemoryStatus.REGISTERED || memoryStatus == MemoryStatus.VALUE_PRESENT && optional.isPresent() || memoryStatus == MemoryStatus.VALUE_ABSENT && optional.isEmpty(); } public Schedule getSchedule() { return this.schedule; } public void setSchedule(Schedule newSchedule) { this.schedule = newSchedule; } public void setCoreActivities(Set newActivities) { this.coreActivities = newActivities; } @Deprecated @VisibleForDebug public Set getActiveActivities() { return this.activeActivities; } @Deprecated @VisibleForDebug public List> getRunningBehaviors() { List> list = new ObjectArrayList<>(); for (Map>> map : this.availableBehaviorsByPriority.values()) { for (Set> set : map.values()) { for (BehaviorControl behaviorControl : set) { if (behaviorControl.getStatus() == Status.RUNNING) { list.add(behaviorControl); } } } } return list; } public void useDefaultActivity() { this.setActiveActivity(this.defaultActivity); } public Optional getActiveNonCoreActivity() { for (Activity activity : this.activeActivities) { if (!this.coreActivities.contains(activity)) { return Optional.of(activity); } } return Optional.empty(); } public void setActiveActivityIfPossible(Activity activity) { if (this.activityRequirementsAreMet(activity)) { this.setActiveActivity(activity); } else { this.useDefaultActivity(); } } private void setActiveActivity(Activity activity) { if (!this.isActive(activity)) { this.eraseMemoriesForOtherActivitesThan(activity); this.activeActivities.clear(); this.activeActivities.addAll(this.coreActivities); this.activeActivities.add(activity); } } private void eraseMemoriesForOtherActivitesThan(Activity activity) { for (Activity activity2 : this.activeActivities) { if (activity2 != activity) { Set> set = (Set>)this.activityMemoriesToEraseWhenStopped.get(activity2); if (set != null) { for (MemoryModuleType memoryModuleType : set) { this.eraseMemory(memoryModuleType); } } } } } public void updateActivityFromSchedule(long dayTime, long gameTime) { if (gameTime - this.lastScheduleUpdate > 20L) { this.lastScheduleUpdate = gameTime; Activity activity = this.getSchedule().getActivityAt((int)(dayTime % 24000L)); if (!this.activeActivities.contains(activity)) { this.setActiveActivityIfPossible(activity); } } } public void setActiveActivityToFirstValid(List activities) { for (Activity activity : activities) { if (this.activityRequirementsAreMet(activity)) { this.setActiveActivity(activity); break; } } } public void setDefaultActivity(Activity newFallbackActivity) { this.defaultActivity = newFallbackActivity; } public void addActivity(Activity activity, int priorityStart, ImmutableList> tasks) { this.addActivity(activity, this.createPriorityPairs(priorityStart, tasks)); } public void addActivityAndRemoveMemoryWhenStopped( Activity activity, int priorityStart, ImmutableList> tasks, MemoryModuleType memoryType ) { Set, MemoryStatus>> set = ImmutableSet.of(Pair.of(memoryType, MemoryStatus.VALUE_PRESENT)); Set> set2 = ImmutableSet.of(memoryType); this.addActivityAndRemoveMemoriesWhenStopped(activity, this.createPriorityPairs(priorityStart, tasks), set, set2); } public void addActivity(Activity activity, ImmutableList>> tasks) { this.addActivityAndRemoveMemoriesWhenStopped(activity, tasks, ImmutableSet.of(), Sets.>newHashSet()); } public void addActivityWithConditions( Activity activity, int priorityStart, ImmutableList> tasks, Set, MemoryStatus>> memoryStatuses ) { this.addActivityWithConditions(activity, this.createPriorityPairs(priorityStart, tasks), memoryStatuses); } public void addActivityWithConditions( Activity activity, ImmutableList>> tasks, Set, MemoryStatus>> memoryStatuses ) { this.addActivityAndRemoveMemoriesWhenStopped(activity, tasks, memoryStatuses, Sets.>newHashSet()); } public void addActivityAndRemoveMemoriesWhenStopped( Activity activity, ImmutableList>> tasks, Set, MemoryStatus>> memorieStatuses, Set> memoryTypes ) { this.activityRequirements.put(activity, memorieStatuses); if (!memoryTypes.isEmpty()) { this.activityMemoriesToEraseWhenStopped.put(activity, memoryTypes); } for (Pair> pair : tasks) { ((Set)((Map)this.availableBehaviorsByPriority.computeIfAbsent(pair.getFirst(), integer -> Maps.newHashMap())) .computeIfAbsent(activity, activityx -> Sets.newLinkedHashSet())) .add(pair.getSecond()); } } @VisibleForTesting public void removeAllBehaviors() { this.availableBehaviorsByPriority.clear(); } public boolean isActive(Activity activity) { return this.activeActivities.contains(activity); } public Brain copyWithoutBehaviors() { Brain brain = new Brain<>(this.memories.keySet(), this.sensors.keySet(), ImmutableList.of(), this.codec); for (Entry, Optional>> entry : this.memories.entrySet()) { MemoryModuleType memoryModuleType = (MemoryModuleType)entry.getKey(); if (((Optional)entry.getValue()).isPresent()) { brain.memories.put(memoryModuleType, (Optional)entry.getValue()); } } return brain; } public void tick(ServerLevel level, E entity) { this.forgetOutdatedMemories(); this.tickSensors(level, entity); this.startEachNonRunningBehavior(level, entity); this.tickEachRunningBehavior(level, entity); } private void tickSensors(ServerLevel level, E brainHolder) { for (Sensor sensor : this.sensors.values()) { sensor.tick(level, brainHolder); } } private void forgetOutdatedMemories() { for (Entry, Optional>> entry : this.memories.entrySet()) { if (((Optional)entry.getValue()).isPresent()) { ExpirableValue expirableValue = (ExpirableValue)((Optional)entry.getValue()).get(); if (expirableValue.hasExpired()) { this.eraseMemory((MemoryModuleType)entry.getKey()); } expirableValue.tick(); } } } public void stopAll(ServerLevel level, E owner) { long l = owner.level().getGameTime(); for (BehaviorControl behaviorControl : this.getRunningBehaviors()) { behaviorControl.doStop(level, owner, l); } } private void startEachNonRunningBehavior(ServerLevel level, E entity) { long l = level.getGameTime(); for (Map>> map : this.availableBehaviorsByPriority.values()) { for (Entry>> entry : map.entrySet()) { Activity activity = (Activity)entry.getKey(); if (this.activeActivities.contains(activity)) { for (BehaviorControl behaviorControl : (Set)entry.getValue()) { if (behaviorControl.getStatus() == Status.STOPPED) { behaviorControl.tryStart(level, entity, l); } } } } } } private void tickEachRunningBehavior(ServerLevel level, E entity) { long l = level.getGameTime(); for (BehaviorControl behaviorControl : this.getRunningBehaviors()) { behaviorControl.tickOrStop(level, entity, l); } } private boolean activityRequirementsAreMet(Activity activity) { if (!this.activityRequirements.containsKey(activity)) { return false; } else { for (Pair, MemoryStatus> pair : (Set)this.activityRequirements.get(activity)) { MemoryModuleType memoryModuleType = pair.getFirst(); MemoryStatus memoryStatus = pair.getSecond(); if (!this.checkMemory(memoryModuleType, memoryStatus)) { return false; } } return true; } } private boolean isEmptyCollection(Object collection) { return collection instanceof Collection && ((Collection)collection).isEmpty(); } ImmutableList>> createPriorityPairs( int priorityStart, ImmutableList> tasks ) { int i = priorityStart; Builder>> builder = ImmutableList.builder(); for (BehaviorControl behaviorControl : tasks) { builder.add(Pair.of(i++, behaviorControl)); } return builder.build(); } static final class MemoryValue { private final MemoryModuleType type; private final Optional> value; static Brain.MemoryValue createUnchecked(MemoryModuleType memoryType, Optional> memory) { return new Brain.MemoryValue<>(memoryType, (Optional>)memory); } MemoryValue(MemoryModuleType type, Optional> value) { this.type = type; this.value = value; } void setMemoryInternal(Brain brain) { brain.setMemoryInternal(this.type, this.value); } public void serialize(DynamicOps ops, RecordBuilder builder) { this.type .getCodec() .ifPresent( codec -> this.value .ifPresent( expirableValue -> builder.add(BuiltInRegistries.MEMORY_MODULE_TYPE.byNameCodec().encodeStart(ops, this.type), codec.encodeStart(ops, expirableValue)) ) ); } } public static final class Provider { private final Collection> memoryTypes; private final Collection>> sensorTypes; private final Codec> codec; Provider(Collection> memoryTypes, Collection>> sensorTypes) { this.memoryTypes = memoryTypes; this.sensorTypes = sensorTypes; this.codec = Brain.codec(memoryTypes, sensorTypes); } public Brain makeBrain(Dynamic ops) { return (Brain)this.codec .parse(ops) .resultOrPartial(Brain.LOGGER::error) .orElseGet(() -> new Brain(this.memoryTypes, this.sensorTypes, ImmutableList.of(), () -> this.codec)); } } }