package org.elasticsearch.xpack.inference.services.cohere;

import java.util.List;
import java.util.Map;
import java.util.Set;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel;

/* loaded from: input_file:org/elasticsearch/xpack/inference/services/cohere/CohereService.class */
/**
 * Inference service backed by Cohere.
 *
 * <p>Only {@link TaskType#TEXT_EMBEDDING} and {@link TaskType#RERANK} models can be created; any other task type is
 * rejected with {@link RestStatus#BAD_REQUEST}.
 */
public class CohereService extends SenderService {
    public static final String NAME = "cohere";

    /**
     * Batch size handed to {@link EmbeddingRequestChunker} for chunked inference. Presumably this mirrors the maximum
     * number of texts Cohere's embed endpoint accepts per call (96); confirm against the Cohere API limits before
     * changing it.
     */
    private static final int EMBEDDING_MAX_BATCH_SIZE = 96;

    /**
     * Creates the service.
     *
     * @param factory           HTTP sender factory, passed through to {@link SenderService}
     * @param serviceComponents shared service components, passed through to {@link SenderService}
     */
    public CohereService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
        super(factory, serviceComponents);
    }

    /** Returns the service name, {@value #NAME}. */
    public String name() {
        return NAME;
    }

    /**
     * Parses a model configuration supplied in a user request and reports the outcome to {@code parsedModelListener}.
     *
     * <p>{@code service_settings} is required and {@code task_settings} defaults to an empty map. Leftover keys in the
     * top-level config, the service settings or the task settings after parsing result in a failure.
     *
     * @param inferenceEntityId     id of the inference endpoint being configured
     * @param taskType              requested task type; must be TEXT_EMBEDDING or RERANK
     * @param config                request configuration; the settings entries are removed from it while parsing
     * @param platformArchitectures not used by this service
     * @param parsedModelListener   receives the parsed model, or any exception raised while parsing
     */
    public void parseRequestConfig(
        String inferenceEntityId,
        TaskType taskType,
        Map<String, Object> config,
        Set<String> platformArchitectures,
        ActionListener<Model> parsedModelListener
    ) {
        try {
            Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
            Map<String, Object> taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, "task_settings");

            // In a request there is no separate secrets section: secret settings are read from the same
            // service_settings map, which is why it is passed twice.
            CohereModel model = createModel(
                inferenceEntityId,
                taskType,
                serviceSettingsMap,
                taskSettingsMap,
                serviceSettingsMap,
                TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
                ConfigurationParseContext.REQUEST
            );

            // Reject unrecognised settings instead of silently ignoring them.
            ServiceUtils.throwIfNotEmptyMap(config, NAME);
            ServiceUtils.throwIfNotEmptyMap(serviceSettingsMap, NAME);
            ServiceUtils.throwIfNotEmptyMap(taskSettingsMap, NAME);

            parsedModelListener.onResponse(model);
        } catch (Exception e) {
            parsedModelListener.onFailure(e);
        }
    }

    /**
     * Creates a model from persisted settings, parsing with {@link ConfigurationParseContext#PERSISTENT}.
     */
    private static CohereModel createModelWithoutLoggingDeprecations(
        String inferenceEntityId,
        TaskType taskType,
        Map<String, Object> serviceSettings,
        Map<String, Object> taskSettings,
        @Nullable Map<String, Object> secretSettings,
        String failureMessage
    ) {
        return createModel(
            inferenceEntityId,
            taskType,
            serviceSettings,
            taskSettings,
            secretSettings,
            failureMessage,
            ConfigurationParseContext.PERSISTENT
        );
    }

    /**
     * Instantiates the Cohere model matching {@code taskType}.
     *
     * @param failureMessage message of the {@link ElasticsearchStatusException} thrown for unsupported task types
     * @throws ElasticsearchStatusException with {@link RestStatus#BAD_REQUEST} if the task type is not supported
     */
    private static CohereModel createModel(
        String inferenceEntityId,
        TaskType taskType,
        Map<String, Object> serviceSettings,
        Map<String, Object> taskSettings,
        @Nullable Map<String, Object> secretSettings,
        String failureMessage,
        ConfigurationParseContext context
    ) {
        // Switch directly on the enum constant rather than through an ordinal()-indexed lookup table.
        switch (taskType) {
            case TEXT_EMBEDDING:
                return new CohereEmbeddingsModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context);
            case RERANK:
                return new CohereRerankModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context);
            default:
                throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
        }
    }

    /**
     * Rebuilds a model from its persisted configuration and secrets.
     * {@code service_settings}, {@code task_settings} and {@code secret_settings} are all required.
     */
    public CohereModel parsePersistedConfigWithSecrets(
        String inferenceEntityId,
        TaskType taskType,
        Map<String, Object> config,
        Map<String, Object> secrets
    ) {
        Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
        Map<String, Object> taskSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "task_settings");
        Map<String, Object> secretSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(secrets, "secret_settings");

        return createModelWithoutLoggingDeprecations(
            inferenceEntityId,
            taskType,
            serviceSettingsMap,
            taskSettingsMap,
            secretSettingsMap,
            ServiceUtils.parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
        );
    }

    /**
     * Rebuilds a model from its persisted configuration without any secret settings.
     * {@code service_settings} and {@code task_settings} are both required.
     */
    public CohereModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
        Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
        Map<String, Object> taskSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "task_settings");

        return createModelWithoutLoggingDeprecations(
            inferenceEntityId,
            taskType,
            serviceSettingsMap,
            taskSettingsMap,
            null,
            ServiceUtils.parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
        );
    }

    /**
     * Runs inference on a query plus input documents, wrapped as {@link QueryAndDocsInputs}.
     * Fails the listener if {@code model} is not a {@link CohereModel}.
     */
    @Override // org.elasticsearch.xpack.inference.services.SenderService
    public void doInfer(
        Model model,
        String query,
        List<String> input,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        if (model instanceof CohereModel == false) {
            listener.onFailure(ServiceUtils.createInvalidModelException(model));
            return;
        }
        CohereModel cohereModel = (CohereModel) model;
        cohereModel.accept(newActionCreator(), taskSettings, inputType).execute(new QueryAndDocsInputs(query, input), timeout, listener);
    }

    /**
     * Runs inference on input documents only, wrapped as {@link DocumentsOnlyInput}.
     * Fails the listener if {@code model} is not a {@link CohereModel}.
     */
    @Override // org.elasticsearch.xpack.inference.services.SenderService
    public void doInfer(
        Model model,
        List<String> input,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        if (model instanceof CohereModel == false) {
            listener.onFailure(ServiceUtils.createInvalidModelException(model));
            return;
        }
        CohereModel cohereModel = (CohereModel) model;
        cohereModel.accept(newActionCreator(), taskSettings, inputType).execute(new DocumentsOnlyInput(input), timeout, listener);
    }

    /**
     * Splits {@code input} into batches of at most {@link #EMBEDDING_MAX_BATCH_SIZE} with
     * {@link EmbeddingRequestChunker} and executes one request per batch, each reporting to the per-batch listener
     * the chunker derives from {@code listener}. {@code query} and {@code chunkingOptions} are not used.
     */
    @Override // org.elasticsearch.xpack.inference.services.SenderService
    protected void doChunkedInfer(
        Model model,
        @Nullable String query,
        List<String> input,
        Map<String, Object> taskSettings,
        InputType inputType,
        ChunkingOptions chunkingOptions,
        TimeValue timeout,
        ActionListener<List<ChunkedInferenceServiceResults>> listener
    ) {
        if (model instanceof CohereModel == false) {
            listener.onFailure(ServiceUtils.createInvalidModelException(model));
            return;
        }
        CohereModel cohereModel = (CohereModel) model;
        CohereActionCreator actionCreator = newActionCreator();

        List<EmbeddingRequestChunker.BatchRequestAndListener> batches = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE)
            .batchRequestsWithListeners(listener);
        for (EmbeddingRequestChunker.BatchRequestAndListener batch : batches) {
            cohereModel.accept(actionCreator, taskSettings, inputType)
                .execute(new DocumentsOnlyInput(batch.batch().inputs()), timeout, batch.listener());
        }
    }

    /** Creates the action creator used to build Cohere requests for this service's sender. */
    private CohereActionCreator newActionCreator() {
        return new CohereActionCreator(getSender(), getServiceComponents());
    }

    /**
     * Returns non-embedding models unchanged. For a {@link CohereEmbeddingsModel}, determines the embedding size via
     * {@link ServiceUtils#getEmbeddingSize} and responds with a copy of the model carrying that dimension.
     */
    public void checkModelConfig(Model model, ActionListener<Model> listener) {
        if (model instanceof CohereEmbeddingsModel == false) {
            listener.onResponse(model);
            return;
        }
        CohereEmbeddingsModel embeddingsModel = (CohereEmbeddingsModel) model;
        ServiceUtils.getEmbeddingSize(
            model,
            this,
            listener.delegateFailureAndWrap(
                (delegate, embeddingSize) -> delegate.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, embeddingSize.intValue()))
            )
        );
    }

    /**
     * Returns a copy of {@code model} whose common settings record {@code embeddingSize} as the dimensions.
     * If no similarity was configured, it falls back to {@link SimilarityMeasure#DOT_PRODUCT}.
     */
    private CohereEmbeddingsModel updateModelWithEmbeddingDetails(CohereEmbeddingsModel model, int embeddingSize) {
        CohereEmbeddingsServiceSettings serviceSettings = model.m49getServiceSettings();
        CohereServiceSettings commonSettings = serviceSettings.getCommonSettings();

        SimilarityMeasure configuredSimilarity = serviceSettings.similarity();
        SimilarityMeasure similarityToUse = configuredSimilarity == null ? SimilarityMeasure.DOT_PRODUCT : configuredSimilarity;

        CohereServiceSettings updatedCommonSettings = new CohereServiceSettings(
            commonSettings.uri(),
            similarityToUse,
            Integer.valueOf(embeddingSize),
            commonSettings.maxInputTokens(),
            commonSettings.modelId()
        );
        return new CohereEmbeddingsModel(
            model,
            new CohereEmbeddingsServiceSettings(updatedCommonSettings, serviceSettings.getEmbeddingType())
        );
    }

    /** Returns the minimum transport version this service requires. */
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.ML_INFERENCE_L2_NORM_SIMILARITY_ADDED;
    }

    /**
     * Decompiler-renamed bridge for {@link #parsePersistedConfig(String, TaskType, Map)} that returns the
     * {@link Model} supertype; kept for compatibility.
     */
    @SuppressWarnings("unchecked") // raw Map from the bridge signature; the delegate expects Map<String, Object>
    public /* bridge */ /* synthetic */ Model m41parsePersistedConfig(String str, TaskType taskType, Map map) {
        return parsePersistedConfig(str, taskType, (Map<String, Object>) map);
    }

    /**
     * Decompiler-renamed bridge for {@link #parsePersistedConfigWithSecrets(String, TaskType, Map, Map)} that returns
     * the {@link Model} supertype; kept for compatibility.
     */
    @SuppressWarnings("unchecked") // raw Maps from the bridge signature; the delegate expects Map<String, Object>
    public /* bridge */ /* synthetic */ Model m42parsePersistedConfigWithSecrets(String str, TaskType taskType, Map map, Map map2) {
        return parsePersistedConfigWithSecrets(str, taskType, (Map<String, Object>) map, (Map<String, Object>) map2);
    }
}
