package org.apache.lucene.util.hnsw;

import java.io.IOException;
import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;

/* loaded from: input_file:WEB-INF/lib/lucene-core-9.2.0.jar:org/apache/lucene/util/hnsw/HnswGraphBuilder.class */
/**
 * Incrementally builds an {@link OnHeapHnswGraph} over a set of vectors.
 *
 * <p>Each node is assigned a random top level, the graph is searched from its current entry node
 * down to that level, and the node is then linked on every level from there down to 0 to a
 * "diverse" subset of the candidates found by the search.
 *
 * <p>The builder reuses mutable scratch state ({@code scratch}, {@code bound}) between calls, so a
 * single instance must not be used from several threads at once.
 */
public final class HnswGraphBuilder {
    /** Default seed for the level-assignment random number generator. */
    private static final long DEFAULT_RAND_SEED = 42;

    /** Component name used for all messages written to the {@link InfoStream}. */
    public static final String HNSW_COMPONENT = "HNSW";

    /**
     * Publicly adjustable seed value, initialised to the default seed. It is not read inside this
     * class; callers are expected to pass it (or another seed) to the constructor.
     */
    public static long randSeed = DEFAULT_RAND_SEED;

    /** Maximum number of connections per node on levels above 0 (level 0 allows twice as many). */
    private final int M;

    /** Number of candidates requested from the searcher when linking a node on a level. */
    private final int beamWidth;

    /** Normalisation factor applied to {@code -ln(random)} when drawing a node's level. */
    private final double ml;

    /** Reusable buffer that receives the search candidates drained from the queue. */
    private final NeighborArray scratch;

    private final VectorSimilarityFunction similarityFunction;

    /** First view of the vectors; used for searching and for fetching the candidate vector. */
    private final RandomAccessVectorValues vectorValues;

    private final SplittableRandom random;
    private final BoundsChecker bound;
    private final HnswGraphSearcher graphSearcher;

    /** The graph under construction. */
    final OnHeapHnswGraph hnsw;

    private InfoStream infoStream = InfoStream.getDefault();

    /**
     * Second, independent view of the same vectors. Diversity checks compare two vectors at once;
     * reading the second one through a separate accessor presumably avoids clobbering a buffer
     * shared by the first (confirm against RandomAccessVectorValues' contract).
     */
    private final RandomAccessVectorValues buildVectors;

    /**
     * Creates a builder and the (initially single-node) graph it will populate.
     *
     * @param vectors producer of random-access views over the vectors to index; two views are
     *     obtained from it
     * @param similarityFunction similarity used to compare vectors; must not be null
     * @param M maximum number of connections per node (doubled on level 0); must be positive
     * @param beamWidth number of candidates tracked while searching for neighbors; must be positive
     * @param seed seed for the random level assignment
     * @throws IllegalArgumentException if {@code M} or {@code beamWidth} is not positive
     * @throws NullPointerException if {@code similarityFunction} is null
     * @throws IOException if a vector view cannot be obtained
     */
    public HnswGraphBuilder(
            RandomAccessVectorValuesProducer vectors,
            VectorSimilarityFunction similarityFunction,
            int M,
            int beamWidth,
            long seed)
            throws IOException {
        this.vectorValues = vectors.randomAccess();
        this.buildVectors = vectors.randomAccess();
        this.similarityFunction = Objects.requireNonNull(similarityFunction);
        if (M <= 0) {
            throw new IllegalArgumentException("maxConn must be positive");
        }
        if (beamWidth <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        this.M = M;
        this.beamWidth = beamWidth;
        // ln(1) == 0, so for M == 1 the usual 1/ln(M) would be +Infinity and every drawn level
        // would become Integer.MAX_VALUE after the int cast. Fall back to a factor of 1 there.
        this.ml = M == 1 ? 1.0d : 1.0d / Math.log(1.0d * M);
        this.random = new SplittableRandom(seed);
        int levelOfFirstNode = getRandomGraphLevel(this.ml, this.random);
        this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode, similarityFunction.reversed);
        this.graphSearcher =
                new HnswGraphSearcher(
                        similarityFunction,
                        new NeighborQueue(beamWidth, !similarityFunction.reversed),
                        new FixedBitSet(this.vectorValues.size()));
        this.bound = BoundsChecker.create(similarityFunction.reversed);
        // Sized to hold either a full set of search results or a neighbor list that temporarily
        // exceeds M by one entry.
        this.scratch = new NeighborArray(Math.max(beamWidth, M + 1), similarityFunction.reversed);
    }

    /**
     * Adds every vector with ordinal 1 and above to the graph and returns the graph.
     *
     * <p>Ordinal 0 is not added here; the graph constructor is given the level of the first node,
     * so it presumably registers node 0 itself (confirm in OnHeapHnswGraph).
     *
     * @param randomAccessVectorValues the vectors to insert; must be a different instance from the
     *     view the builder obtained in its constructor
     * @return the graph built by this builder
     * @throws IllegalArgumentException if the argument is the builder's own vector view
     * @throws IOException if reading a vector fails
     */
    public OnHeapHnswGraph build(RandomAccessVectorValues randomAccessVectorValues) throws IOException {
        if (randomAccessVectorValues == this.vectorValues) {
            throw new IllegalArgumentException("Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
        }
        if (this.infoStream.isEnabled(HNSW_COMPONENT)) {
            this.infoStream.message(HNSW_COMPONENT, "build graph from " + randomAccessVectorValues.size() + " vectors");
        }
        final long start = System.nanoTime();
        long lastReport = start;
        for (int node = 1; node < randomAccessVectorValues.size(); node++) {
            addGraphNode(node, randomAccessVectorValues.vectorValue(node));
            // Progress line every 10000 nodes, only when the HNSW component is enabled.
            if (node % 10000 == 0 && this.infoStream.isEnabled(HNSW_COMPONENT)) {
                lastReport = printGraphBuildStatus(node, start, lastReport);
            }
        }
        return this.hnsw;
    }

    /**
     * Replaces the stream that receives build progress messages.
     *
     * @param infoStream the stream to report to
     */
    public void setInfoStream(InfoStream infoStream) {
        this.infoStream = infoStream;
    }

    /**
     * Inserts one node into the graph.
     *
     * @param node ordinal of the node being added
     * @param value the node's vector
     * @throws IOException if reading a vector fails
     */
    void addGraphNode(int node, float[] value) throws IOException {
        final int nodeLevel = getRandomGraphLevel(this.ml, this.random);
        final int curMaxLevel = this.hnsw.numLevels() - 1;
        int[] entryPoints = {this.hnsw.entryNode()};

        // Levels the graph does not have yet: the new node is simply registered on each of them.
        for (int level = nodeLevel; level > curMaxLevel; level--) {
            this.hnsw.addNode(level, node);
        }

        // Existing levels above the node's own level: keep only the single best hit as the entry
        // point for the next level down.
        for (int level = curMaxLevel; level > nodeLevel; level--) {
            NeighborQueue best =
                    this.graphSearcher.searchLevel(value, 1, level, entryPoints, this.vectorValues, this.hnsw);
            entryPoints = new int[] {best.pop()};
        }

        // Levels the node lives on: search with the full beam width, then link.
        for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) {
            NeighborQueue candidates =
                    this.graphSearcher.searchLevel(value, this.beamWidth, level, entryPoints, this.vectorValues, this.hnsw);
            // Capture the entry points for the next level before the queue is drained below.
            entryPoints = candidates.nodes();
            this.hnsw.addNode(level, node);
            addDiverseNeighbors(level, node, candidates);
        }
    }

    /**
     * Writes a progress message and returns the current time for use as the next interval start.
     *
     * @param node number reported as "built"
     * @param start {@link System#nanoTime()} value taken when the build began
     * @param lastReport {@link System#nanoTime()} value taken at the previous report
     * @return the {@link System#nanoTime()} value taken for this report
     */
    private long printGraphBuildStatus(int node, long start, long lastReport) {
        final long now = System.nanoTime();
        final long sinceLastMs = (now - lastReport) / 1000000;
        final long totalMs = (now - start) / 1000000;
        this.infoStream.message(
                HNSW_COMPONENT,
                String.format(Locale.ROOT, "built %d in %d/%d ms", node, sinceLastMs, totalMs));
        return now;
    }

    /**
     * Links a freshly added node to a diverse subset of the search candidates and adds the
     * reverse links, trimming any neighbor list that grows beyond the per-level limit.
     *
     * @param level graph level being linked
     * @param node ordinal of the new node
     * @param candidates search results for the new node on this level; drained by this call
     */
    private void addDiverseNeighbors(int level, int node, NeighborQueue candidates) throws IOException {
        NeighborArray neighbors = this.hnsw.getNeighbors(level, node);
        assert neighbors.size() == 0; // the node was only just added on this level
        popToScratch(candidates);
        final int maxConnOnLevel = level == 0 ? this.M * 2 : this.M;
        selectAndLinkDiverse(neighbors, this.scratch, maxConnOnLevel);

        // Add the reverse edge on every selected neighbor; if that overflows the neighbor's list,
        // drop one entry chosen by findWorstNonDiverse.
        final int selected = neighbors.size();
        for (int k = 0; k < selected; k++) {
            NeighborArray reverse = this.hnsw.getNeighbors(level, neighbors.node[k]);
            reverse.insertSorted(node, neighbors.score[k]);
            if (reverse.size() > maxConnOnLevel) {
                reverse.removeIndex(findWorstNonDiverse(reverse));
            }
        }
    }

    /**
     * Copies candidates into {@code neighbors}, walking {@code candidates} from its last entry to
     * its first and keeping only those that pass {@link #diversityCheck}, until
     * {@code maxConnOnLevel} entries have been selected or the candidates run out.
     */
    private void selectAndLinkDiverse(NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel)
            throws IOException {
        for (int idx = candidates.size() - 1; neighbors.size() < maxConnOnLevel && idx >= 0; idx--) {
            final int candidateNode = candidates.node[idx];
            final float candidateScore = candidates.score[idx];
            assert candidateNode < this.hnsw.size();
            if (diversityCheck(this.vectorValues.vectorValue(candidateNode), candidateScore, neighbors, this.buildVectors)) {
                neighbors.add(candidateNode, candidateScore);
            }
        }
    }

    /**
     * Empties {@code candidates} into the scratch array in pop order, pairing every node with its
     * own score.
     */
    private void popToScratch(NeighborQueue candidates) {
        this.scratch.clear();
        final int candidateCount = candidates.size();
        for (int k = 0; k < candidateCount; k++) {
            // The score must be read BEFORE pop(): arguments are evaluated left to right, so
            // calling topScore() after pop() would report the next element's score (and would
            // query an empty queue on the final iteration).
            final float score = candidates.topScore();
            this.scratch.add(candidates.pop(), score);
        }
    }

    /**
     * Returns whether a candidate passes the bounds check against every already selected neighbor.
     *
     * @param candidate vector of the candidate
     * @param score score of the candidate relative to the node being linked; used as the bound
     * @param neighbors neighbors selected so far
     * @param vectorValues accessor used to read the neighbors' vectors
     * @return {@code false} as soon as one comparison fails the bound, {@code true} otherwise
     */
    private boolean diversityCheck(
            float[] candidate, float score, NeighborArray neighbors, RandomAccessVectorValues vectorValues)
            throws IOException {
        this.bound.set(score);
        for (int k = 0; k < neighbors.size(); k++) {
            final float similarityToNeighbor =
                    this.similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node[k]));
            if (!this.bound.check(similarityToNeighbor)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Picks the entry to evict from an over-full neighbor list.
     *
     * <p>Scans from the last entry towards the first and returns the first index whose vector
     * fails the bounds check against any entry stored before it; if none fails, the last index is
     * returned.
     */
    private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
        for (int idx = neighbors.size() - 1; idx > 0; idx--) {
            final float[] current = this.vectorValues.vectorValue(neighbors.node[idx]);
            this.bound.set(neighbors.score[idx]);
            for (int earlier = idx - 1; earlier >= 0; earlier--) {
                // The second vector is read through buildVectors, not vectorValues, so that
                // `current` stays valid during the comparison.
                final float similarity =
                        this.similarityFunction.compare(current, this.buildVectors.vectorValue(neighbors.node[earlier]));
                if (!this.bound.check(similarity)) {
                    return idx;
                }
            }
        }
        return neighbors.size() - 1;
    }

    /**
     * Draws a random level as {@code floor(-ln(u) * ml)} for a uniform {@code u} in (0, 1).
     *
     * @param ml normalisation factor
     * @param random source of randomness
     * @return the drawn level, never negative for a non-negative {@code ml}
     */
    private static int getRandomGraphLevel(double ml, SplittableRandom random) {
        double u;
        do {
            u = random.nextDouble();
        } while (u == 0.0d); // ln(0) is -Infinity; redraw
        return (int) ((-Math.log(u)) * ml);
    }
}
