/*
 * Decompiled with CFR 0.152.
 */
package oracle.pgx.api.mllib;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import oracle.pgx.api.PgxFuture;
import oracle.pgx.api.PgxGraph;
import oracle.pgx.api.PgxSession;
import oracle.pgx.api.PgxVertex;
import oracle.pgx.api.frames.PgxFrame;
import oracle.pgx.api.frames.internal.PgxFrameImpl;
import oracle.pgx.api.internal.Core;
import oracle.pgx.api.internal.FrameMetaData;
import oracle.pgx.api.internal.Graph;
import oracle.pgx.api.internal.mllib.GnnExplanationMetaData;
import oracle.pgx.api.internal.mllib.ModelMetadata;
import oracle.pgx.api.internal.mllib.SupervisedGnnExplainerConfig;
import oracle.pgx.api.internal.mllib.SupervisedGraphWiseModelMetadata;
import oracle.pgx.api.mllib.FileModelStorer;
import oracle.pgx.api.mllib.GnnExplanation;
import oracle.pgx.api.mllib.GraphWiseModel;
import oracle.pgx.api.mllib.SupervisedGnnExplainer;
import oracle.pgx.api.mllib.SupervisedGnnExplanation;
import oracle.pgx.common.util.ErrorMessages;
import oracle.pgx.config.mllib.GraphWisePredictionLayerConfig;
import oracle.pgx.config.mllib.ModelKind;
import oracle.pgx.config.mllib.SupervisedGraphWiseModelConfig;
import oracle.pgx.config.mllib.loss.LossFunction;

/**
 * Client-side handle for a supervised GraphWise model.
 *
 * <p>Every operation is forwarded to the {@code Core} instance held by this model, using the
 * session context of the owning {@link PgxSession} and the model name stored in the current
 * model metadata. Calls that produce tabular results wrap the returned frame metadata in a
 * {@link PgxFrameImpl}.
 *
 * <p>The overloads that take no explicit threshold forward {@code 0.0f}. The meaning of the
 * threshold is defined by the server-side call and is not visible in this class.
 */
public class SupervisedGraphWiseModel
        extends GraphWiseModel<SupervisedGraphWiseModelConfig, SupervisedGraphWiseModelMetadata, SupervisedGraphWiseModel> {
    public static final String ALGORITHM_NAME = "SupervisedGraphWise";

    /** Threshold forwarded by the overloads that do not take an explicit threshold. */
    private static final float DEFAULT_THRESHOLD = 0.0f;

    /**
     * Creates a model handle from already-typed metadata.
     *
     * @param session owning session
     * @param core backend used to execute model operations
     * @param keystorePathSupplier supplies the keystore path (passed on to created frames)
     * @param keystorePasswordSupplier supplies the keystore password (passed on to created frames)
     * @param graphConstructor factory handed to the superclass
     * @param modelMetadata metadata describing this model
     */
    public SupervisedGraphWiseModel(PgxSession session, Core core, Supplier<String> keystorePathSupplier, Supplier<char[]> keystorePasswordSupplier, BiFunction<PgxSession, Graph, PgxGraph> graphConstructor, SupervisedGraphWiseModelMetadata modelMetadata) {
        super(session, core, keystorePathSupplier, keystorePasswordSupplier, modelMetadata, graphConstructor);
    }

    /**
     * Creates a model handle from generic metadata.
     *
     * <p>The superclass is initialised with {@code null} metadata. The metadata is stored only
     * when its model kind equals {@link #getModelKind()}; otherwise the mismatch is reported
     * through {@code ErrorMessages.throwException} with the key {@code UNEXPECTED_MODEL_KIND}
     * (presumably raising an {@link IllegalArgumentException} - confirm against
     * {@code ErrorMessages}).
     *
     * @param session owning session
     * @param core backend used to execute model operations
     * @param keystorePathSupplier supplies the keystore path (passed on to created frames)
     * @param keystorePasswordSupplier supplies the keystore password (passed on to created frames)
     * @param graphConstructor factory handed to the superclass
     * @param modelMetadata metadata whose model kind must match this class
     */
    public SupervisedGraphWiseModel(PgxSession session, Core core, Supplier<String> keystorePathSupplier, Supplier<char[]> keystorePasswordSupplier, BiFunction<PgxSession, Graph, PgxGraph> graphConstructor, ModelMetadata modelMetadata) {
        super(session, core, keystorePathSupplier, keystorePasswordSupplier, null, graphConstructor);
        if (modelMetadata.getModelKind() == this.getModelKind()) {
            // Kinds agree, so the narrowing cast is expected to succeed.
            this.modelMetadata = (SupervisedGraphWiseModelMetadata) modelMetadata;
        } else {
            ErrorMessages.throwException(IllegalArgumentException::new, "UNEXPECTED_MODEL_KIND", new Object[]{this.getModelKind(), modelMetadata.getModelKind()});
        }
    }

    /** Returns this instance. */
    @Override
    protected SupervisedGraphWiseModel getThis() {
        return this;
    }

    /** Returns {@link ModelKind#SUPERVISED_GRAPHWISE}. */
    @Override
    protected ModelKind getModelKind() {
        return ModelKind.SUPERVISED_GRAPHWISE;
    }

    /**
     * Fits the model on {@code graph}.
     *
     * <p>When the backend call completes, the metadata held by this object is replaced with the
     * returned metadata and the training loss is read from its config.
     *
     * @param graph graph to train on
     * @return future yielding the training loss
     */
    @Override
    public PgxFuture<Double> fitAsync(PgxGraph graph) {
        return core.fitSupervisedGraphWiseModel(session.getSessionContext(), storedModelName(), graph.getId()).thenApply(updated -> {
            this.modelMetadata = updated;
            SupervisedGraphWiseModelConfig trainedConfig = (SupervisedGraphWiseModelConfig) updated.getConfig();
            return trainedConfig.getTrainingLoss();
        });
    }

    /**
     * Fits the model on {@code trainGraph}, additionally passing {@code valGraph} to the backend.
     *
     * <p>When the backend call completes, the metadata held by this object is replaced with the
     * returned metadata and the training loss is read from its config.
     *
     * @param trainGraph graph to train on
     * @param valGraph second graph forwarded to the backend (validation graph, by its name)
     * @return future yielding the training loss
     */
    @Override
    public PgxFuture<Double> fitAsync(PgxGraph trainGraph, PgxGraph valGraph) {
        return core.fitSupervisedGraphWiseModel(session.getSessionContext(), storedModelName(), trainGraph.getId(), valGraph.getId()).thenApply(updated -> {
            this.modelMetadata = updated;
            SupervisedGraphWiseModelConfig trainedConfig = (SupervisedGraphWiseModelConfig) updated.getConfig();
            return trainedConfig.getTrainingLoss();
        });
    }

    /**
     * Requests the training log of this model.
     *
     * @return future yielding the training log as a frame
     */
    @Override
    public PgxFuture<PgxFrame> getTrainingLogAsync() {
        return core.getTrainingLogSupervisedGraphWiseModel(session.getSessionContext(), storedModelName())
                .thenApply(this::frameFromMetaData);
    }

    /**
     * Requests embeddings for the given vertices.
     *
     * @param graph graph the vertices belong to
     * @param vertices vertices to process; serialized before being sent
     * @return future yielding the result frame
     */
    @Override
    public <ID> PgxFuture<PgxFrame> inferEmbeddingsAsync(PgxGraph graph, Iterable<PgxVertex<ID>> vertices) {
        return core.inferEmbeddingsSupervisedGraphWiseModel(session.getSessionContext(), storedModelName(), graph.getId(), serializeVertices(vertices))
                .thenApply(this::frameFromMetaData);
    }

    /**
     * Requests logits for the given vertices.
     *
     * @param graph graph the vertices belong to
     * @param vertices vertices to process; serialized before being sent
     * @return future yielding the result frame
     */
    public <ID> PgxFuture<PgxFrame> inferLogitsAsync(PgxGraph graph, Iterable<PgxVertex<ID>> vertices) {
        return core.inferLogitsSupervisedGraphWiseModel(session.getSessionContext(), storedModelName(), graph.getId(), serializeVertices(vertices))
                .thenApply(this::frameFromMetaData);
    }

    /** Blocking variant of {@link #inferLogitsAsync(PgxGraph, Iterable)}. */
    public <ID> PgxFrame inferLogits(PgxGraph graph, Iterable<PgxVertex<ID>> vertices) {
        PgxFuture<PgxFrame> pending = inferLogitsAsync(graph, vertices);
        return pending.join();
    }

    /** Same as {@link #inferLabelsAsync(PgxGraph, Iterable, float)} with a threshold of {@code 0.0f}. */
    public <ID> PgxFuture<PgxFrame> inferLabelsAsync(PgxGraph graph, Iterable<PgxVertex<ID>> vertices) {
        return inferLabelsAsync(graph, vertices, DEFAULT_THRESHOLD);
    }

    /**
     * Requests labels for the given vertices.
     *
     * @param graph graph the vertices belong to
     * @param vertices vertices to process; serialized before being sent
     * @param threshold value forwarded unchanged to the backend call
     * @return future yielding the result frame
     */
    public <ID> PgxFuture<PgxFrame> inferLabelsAsync(PgxGraph graph, Iterable<PgxVertex<ID>> vertices, float threshold) {
        return core.inferLabelsSupervisedGraphWiseModel(session.getSessionContext(), storedModelName(), graph.getId(), serializeVertices(vertices), threshold)
                .thenApply(this::frameFromMetaData);
    }

    /** Same as {@link #evaluateLabelsAsync(PgxGraph, Iterable, float)} with a threshold of {@code 0.0f}. */
    public <ID> PgxFuture<PgxFrame> evaluateLabelsAsync(PgxGraph graph, Iterable<PgxVertex<ID>> vertices) {
        return evaluateLabelsAsync(graph, vertices, DEFAULT_THRESHOLD);
    }

    /**
     * Requests a label evaluation for the given vertices.
     *
     * @param graph graph the vertices belong to
     * @param vertices vertices to process; serialized before being sent
     * @param threshold value forwarded unchanged to the backend call
     * @return future yielding the result frame
     */
    public <ID> PgxFuture<PgxFrame> evaluateLabelsAsync(PgxGraph graph, Iterable<PgxVertex<ID>> vertices, float threshold) {
        return core.evaluateLabelsSupervisedGraphWiseModel(session.getSessionContext(), storedModelName(), graph.getId(), serializeVertices(vertices), threshold)
                .thenApply(this::frameFromMetaData);
    }

    /** Blocking label inference with a threshold of {@code 0.0f}. */
    public <ID> PgxFrame inferLabels(PgxGraph graph, Iterable<PgxVertex<ID>> vertices) {
        PgxFuture<PgxFrame> pending = inferLabelsAsync(graph, vertices, DEFAULT_THRESHOLD);
        return pending.join();
    }

    /** Blocking variant of {@link #inferLabelsAsync(PgxGraph, Iterable, float)}. */
    public <ID> PgxFrame inferLabels(PgxGraph graph, Iterable<PgxVertex<ID>> vertices, float threshold) {
        PgxFuture<PgxFrame> pending = inferLabelsAsync(graph, vertices, threshold);
        return pending.join();
    }

    /** Blocking label evaluation with a threshold of {@code 0.0f}. */
    public <ID> PgxFrame evaluateLabels(PgxGraph graph, Iterable<PgxVertex<ID>> vertices) {
        PgxFuture<PgxFrame> pending = evaluateLabelsAsync(graph, vertices, DEFAULT_THRESHOLD);
        return pending.join();
    }

    /** Blocking variant of {@link #evaluateLabelsAsync(PgxGraph, Iterable, float)}. */
    public <ID> PgxFrame evaluateLabels(PgxGraph graph, Iterable<PgxVertex<ID>> vertices, float threshold) {
        PgxFuture<PgxFrame> pending = evaluateLabelsAsync(graph, vertices, threshold);
        return pending.join();
    }

    /** Same as {@link #inferAsync(PgxGraph, Iterable, float)} with a threshold of {@code 0.0f}. */
    public <ID> PgxFuture<PgxFrame> inferAsync(PgxGraph graph, Iterable<PgxVertex<ID>> vertices) {
        return inferAsync(graph, vertices, DEFAULT_THRESHOLD);
    }

    /**
     * Runs inference for the given vertices.
     *
     * @param graph graph the vertices belong to
     * @param vertices vertices to process; serialized before being sent
     * @param threshold value forwarded unchanged to the backend call
     * @return future yielding the result frame
     */
    public <ID> PgxFuture<PgxFrame> inferAsync(PgxGraph graph, Iterable<PgxVertex<ID>> vertices, float threshold) {
        return core.inferSupervisedGraphWiseModel(session.getSessionContext(), storedModelName(), graph.getId(), serializeVertices(vertices), threshold)
                .thenApply(this::frameFromMetaData);
    }

    /** Same as {@link #evaluateAsync(PgxGraph, Iterable, float)} with a threshold of {@code 0.0f}. */
    public <ID> PgxFuture<PgxFrame> evaluateAsync(PgxGraph graph, Iterable<PgxVertex<ID>> vertices) {
        return evaluateAsync(graph, vertices, DEFAULT_THRESHOLD);
    }

    /**
     * Runs an evaluation for the given vertices.
     *
     * @param graph graph the vertices belong to
     * @param vertices vertices to process; serialized before being sent
     * @param threshold value forwarded unchanged to the backend call
     * @return future yielding the result frame
     */
    public <ID> PgxFuture<PgxFrame> evaluateAsync(PgxGraph graph, Iterable<PgxVertex<ID>> vertices, float threshold) {
        return core.evaluateSupervisedGraphWiseModel(session.getSessionContext(), storedModelName(), graph.getId(), serializeVertices(vertices), threshold)
                .thenApply(this::frameFromMetaData);
    }

    /** Blocking inference with a threshold of {@code 0.0f}. */
    public <ID> PgxFrame infer(PgxGraph graph, Iterable<PgxVertex<ID>> vertices) {
        PgxFuture<PgxFrame> pending = inferAsync(graph, vertices, DEFAULT_THRESHOLD);
        return pending.join();
    }

    /** Blocking variant of {@link #inferAsync(PgxGraph, Iterable, float)}. */
    public <ID> PgxFrame infer(PgxGraph graph, Iterable<PgxVertex<ID>> vertices, float threshold) {
        PgxFuture<PgxFrame> pending = inferAsync(graph, vertices, threshold);
        return pending.join();
    }

    /** Blocking evaluation with a threshold of {@code 0.0f}. */
    public <ID> PgxFrame evaluate(PgxGraph graph, Iterable<PgxVertex<ID>> vertices) {
        PgxFuture<PgxFrame> pending = evaluateAsync(graph, vertices, DEFAULT_THRESHOLD);
        return pending.join();
    }

    /** Blocking variant of {@link #evaluateAsync(PgxGraph, Iterable, float)}. */
    public <ID> PgxFrame evaluate(PgxGraph graph, Iterable<PgxVertex<ID>> vertices, float threshold) {
        PgxFuture<PgxFrame> pending = evaluateAsync(graph, vertices, threshold);
        return pending.join();
    }

    /**
     * Infers and explains the prediction for one vertex with a threshold of {@code 0.0f}.
     *
     * @deprecated presumably superseded by {@link #gnnExplainer()} - confirm against the API docs
     */
    @Deprecated
    public <ID> PgxFuture<SupervisedGnnExplanation<ID>> inferAndGetExplanationAsync(PgxGraph graph, PgxVertex<ID> vertex) {
        return inferAndGetExplanationAsync(graph, vertex, DEFAULT_THRESHOLD);
    }

    /**
     * Infers and explains the prediction for one vertex using a freshly constructed
     * {@link SupervisedGnnExplainerConfig}.
     *
     * @deprecated presumably superseded by {@link #gnnExplainer()} - confirm against the API docs
     */
    @Deprecated
    public <ID> PgxFuture<SupervisedGnnExplanation<ID>> inferAndGetExplanationAsync(PgxGraph graph, PgxVertex<ID> vertex, float threshold) {
        SupervisedGnnExplainerConfig freshConfig = new SupervisedGnnExplainerConfig();
        return inferAndExplainWithConfigAsync(graph, vertex, freshConfig, threshold);
    }

    /**
     * Blocking explanation request with a threshold of {@code 0.0f}.
     *
     * @deprecated presumably superseded by {@link #gnnExplainer()} - confirm against the API docs
     */
    @Deprecated
    public <ID> SupervisedGnnExplanation<ID> inferAndGetExplanation(PgxGraph graph, PgxVertex<ID> vertex) {
        PgxFuture<SupervisedGnnExplanation<ID>> pending = inferAndGetExplanationAsync(graph, vertex, DEFAULT_THRESHOLD);
        return pending.join();
    }

    /**
     * Blocking variant of {@link #inferAndGetExplanationAsync(PgxGraph, PgxVertex, float)}.
     *
     * @deprecated presumably superseded by {@link #gnnExplainer()} - confirm against the API docs
     */
    @Deprecated
    public <ID> SupervisedGnnExplanation<ID> inferAndGetExplanation(PgxGraph graph, PgxVertex<ID> vertex, float threshold) {
        PgxFuture<SupervisedGnnExplanation<ID>> pending = inferAndGetExplanationAsync(graph, vertex, threshold);
        return pending.join();
    }

    /**
     * Stores the model at {@code path} without overwriting.
     *
     * @param path target path
     * @param key key handed to the storer
     * @return future completing when the store operation finishes
     * @throws ExecutionException declared for compatibility; not thrown by the visible code
     * @throws InterruptedException declared for compatibility; not thrown by the visible code
     */
    public PgxFuture<Void> storeAsync(String path, String key) throws ExecutionException, InterruptedException {
        return storeAsync(path, key, false);
    }

    /**
     * Stores the model at {@code path} through the file exporter.
     *
     * @param path target path
     * @param key key handed to the storer
     * @param overwrite overwrite flag handed to the storer
     * @return future completing when the store operation finishes
     */
    public PgxFuture<Void> storeAsync(String path, String key, boolean overwrite) {
        FileModelStorer storer = (FileModelStorer) this.export().file().path(path).key(key).overwrite(overwrite);
        return storer.storeAsync();
    }

    /**
     * Blocking variant of {@link #storeAsync(String, String)}.
     *
     * @throws ExecutionException if the store operation fails
     * @throws InterruptedException if the waiting thread is interrupted
     */
    public void store(String path, String key) throws ExecutionException, InterruptedException {
        PgxFuture<Void> pending = storeAsync(path, key);
        pending.get();
    }

    /**
     * Blocking variant of {@link #storeAsync(String, String, boolean)}.
     *
     * @throws ExecutionException if the store operation fails
     * @throws InterruptedException if the waiting thread is interrupted
     */
    public void store(String path, String key, boolean overwrite) throws ExecutionException, InterruptedException {
        PgxFuture<Void> pending = storeAsync(path, key, overwrite);
        pending.get();
    }

    /**
     * Returns the loss function from the model config.
     *
     * @deprecated see {@link #getLossFunctionClass()}, which exposes the loss as a
     *     {@link LossFunction} (presumably the replacement - confirm)
     */
    @Deprecated
    public SupervisedGraphWiseModelConfig.LossFunction getLossFunction() {
        return supervisedConfig().getLossFunction();
    }

    /** Returns the loss function object from the model config. */
    public LossFunction getLossFunctionClass() {
        return supervisedConfig().getLossFunctionClass();
    }

    /** Returns the prediction layer configs from the model config. */
    public GraphWisePredictionLayerConfig[] getPredictionLayerConfigs() {
        return supervisedConfig().getPredictionLayerConfigs();
    }

    /** Returns the class weights from the model config. */
    public Map<?, Float> getClassWeights() {
        return supervisedConfig().getClassWeights();
    }

    /** Returns the vertex target property name from the model config. */
    public String getVertexTargetPropertyName() {
        return supervisedConfig().getVertexTargetPropertyName();
    }

    /**
     * Infers and explains the prediction for one vertex using the supplied explainer config.
     *
     * <p>The backend result is first run through {@code processExplanationResult}; the returned
     * explanation is then combined with the logits and label of the backend result.
     *
     * @param graph graph the vertex belongs to
     * @param vertex vertex to explain; serialized before being sent
     * @param config explainer configuration forwarded to the backend
     * @param threshold value forwarded unchanged to the backend call
     * @return future yielding the explanation
     */
    <ID> PgxFuture<SupervisedGnnExplanation<ID>> inferAndExplainWithConfigAsync(PgxGraph graph, PgxVertex<ID> vertex, SupervisedGnnExplainerConfig config, float threshold) {
        return core.inferAndGetExplanationSupervisedGraphWiseModel(session.getSessionContext(), storedModelName(), graph.getId(), vertex.serialize(), threshold, config).thenApply(backendResult -> {
            GnnExplanation base = processExplanationResult(graph, (GnnExplanationMetaData) backendResult);
            return new SupervisedGnnExplanation(
                    base.getVertexFeatureImportance(),
                    base.getImportanceGraph(),
                    base.getVertexImportanceProperty(),
                    base.getEmbedding(),
                    backendResult.getLogits(),
                    backendResult.getLabel());
        });
    }

    /** Returns a new {@link SupervisedGnnExplainer} bound to this model. */
    public SupervisedGnnExplainer gnnExplainer() {
        return new SupervisedGnnExplainer(this);
    }

    /** Returns the target vertex label sets from the model config. */
    public List<Set<String>> getTargetVertexLabels() {
        return supervisedConfig().getTargetVertexLabelSets();
    }

    /** Model name taken from the metadata currently held by this object. */
    private String storedModelName() {
        return ((SupervisedGraphWiseModelMetadata) this.modelMetadata).getModelName();
    }

    /** Model config narrowed to the supervised GraphWise config type. */
    private SupervisedGraphWiseModelConfig supervisedConfig() {
        return (SupervisedGraphWiseModelConfig) this.getConfig();
    }

    /**
     * Wraps frame metadata returned by the backend in a {@link PgxFrameImpl} bound to this
     * model's session, core and keystore suppliers.
     *
     * @param rawFrameMetaData backend result; cast to {@link FrameMetaData}
     */
    private PgxFrame frameFromMetaData(Object rawFrameMetaData) {
        return new PgxFrameImpl(this.session, this.core, (FrameMetaData) rawFrameMetaData, this.keystorePathSupplier, this.keystorePasswordSupplier);
    }

    /** Names of the inference/evaluation operations offered by this model. */
    public enum SupervisedGraphWiseInferenceType {
        INFER_EMBEDDINGS,
        INFER_LABELS,
        EVALUATE_LABELS,
        INFER_LOGITS,
        INFER,
        EVALUATE
    }
}

