package edu.stanford.nlp.sentiment;

import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.neural.SimpleTensor;
import edu.stanford.nlp.neural.rnn.RNNCoreAnnotations;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.TwoDimensionalMap;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.ejml.simple.SimpleMatrix;

/* loaded from: input_file:BOOT-INF/lib/stanford-corenlp-4.5.6.jar:edu/stanford/nlp/sentiment/SentimentCostAndGradient.class */
/**
 * Cost function and gradient for training a {@link SentimentModel} on a fixed batch of trees.
 *
 * <p>Each call to {@link #calculate(double[])} does the following:
 * <ol>
 *   <li>Loads the parameter vector into the shared model.</li>
 *   <li>Forward-propagates a deep copy of every training tree.</li>
 *   <li>Backpropagates the class-weighted cross-entropy error of every node.</li>
 *   <li>Scales the accumulated derivatives by {@code 1 / batchSize} and adds L2 regularization terms.</li>
 * </ol>
 * Scoring is optionally spread over several threads, as configured by
 * {@code model.op.trainOptions.nThreads}.
 */
public class SentimentCostAndGradient extends AbstractCachingDiffFunction {
    private static final Redwood.RedwoodChannels log = Redwood.channels(SentimentCostAndGradient.class);

    /** Model whose parameters are overwritten on every {@link #calculate(double[])} call. */
    private final SentimentModel model;

    /** Trees over which the cost and gradient are averaged. */
    private final List<Tree> trainingBatch;

    /**
     * Accumulator for the partial derivatives of every parameter group of a
     * {@link SentimentModel}, plus the summed prediction error of the scored trees.
     * Partial results (e.g. from different threads) are merged with {@link #add(ModelDerivatives)}.
     */
    public static class ModelDerivatives {
        /** Derivatives of the binary transform matrices, keyed by (left, right) child category. */
        public final TwoDimensionalMap<String, String, SimpleMatrix> binaryTD;
        /** Derivatives of the binary tensors; empty unless the model uses tensors. */
        public final TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD;
        /** Derivatives of the binary classification matrices; empty when classification is combined. */
        public final TwoDimensionalMap<String, String, SimpleMatrix> binaryCD;
        /** Derivatives of the unary classification matrices, keyed by category. */
        public final Map<String, SimpleMatrix> unaryCD;
        /** Derivatives of word vectors; entries are created lazily for words seen while scoring. */
        public final Map<String, SimpleMatrix> wordVectorD;
        /** Sum of the per-node prediction errors accumulated so far. */
        public double error = 0.0d;

        /**
         * Creates all-zero derivative matrices shaped like the parameters of {@code sentimentModel}.
         * Word-vector derivatives start empty.
         *
         * @param sentimentModel model whose parameter shapes are mirrored
         */
        public ModelDerivatives(SentimentModel sentimentModel) {
            this.binaryTD = initDerivatives(sentimentModel.binaryTransform);
            this.binaryTensorTD = sentimentModel.op.useTensors
                    ? initTensorDerivatives(sentimentModel.binaryTensors)
                    : TwoDimensionalMap.treeMap();
            this.binaryCD = sentimentModel.op.combineClassification
                    ? TwoDimensionalMap.treeMap()
                    : initDerivatives(sentimentModel.binaryClassification);
            this.unaryCD = initDerivatives(sentimentModel.unaryClassification);
            this.wordVectorD = Generics.newTreeMap();
        }

        /**
         * Adds every derivative and the error of {@code other} into this accumulator.
         *
         * @param other partial derivatives to merge in; its matrices may end up shared with this object
         */
        public void add(ModelDerivatives other) {
            addMatrices(this.binaryTD, other.binaryTD);
            addTensors(this.binaryTensorTD, other.binaryTensorTD);
            addMatrices(this.binaryCD, other.binaryCD);
            addMatrices(this.unaryCD, other.unaryCD);
            addMatrices(this.wordVectorD, other.wordVectorD);
            this.error += other.error;
        }

        /**
         * Adds {@code other} into {@code target} entry by entry.
         * Keys present in both maps are summed; keys only in {@code other} are copied by reference.
         *
         * @param target map updated in place
         * @param other map whose entries are added; not modified
         */
        public static void addMatrices(TwoDimensionalMap<String, String, SimpleMatrix> target,
                                       TwoDimensionalMap<String, String, SimpleMatrix> other) {
            // NOTE(review): assumes put() on an already-present key does not invalidate the live iterator.
            for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : target) {
                String first = entry.getFirstKey();
                String second = entry.getSecondKey();
                if (other.contains(first, second)) {
                    target.put(first, second, entry.getValue().plus(other.get(first, second)));
                }
            }
            for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : other) {
                if (!target.contains(entry.getFirstKey(), entry.getSecondKey())) {
                    target.put(entry.getFirstKey(), entry.getSecondKey(), entry.getValue());
                }
            }
        }

        /**
         * Tensor counterpart of {@link #addMatrices(TwoDimensionalMap, TwoDimensionalMap)}.
         *
         * @param target map updated in place
         * @param other map whose entries are added; not modified
         */
        public static void addTensors(TwoDimensionalMap<String, String, SimpleTensor> target,
                                      TwoDimensionalMap<String, String, SimpleTensor> other) {
            // NOTE(review): assumes put() on an already-present key does not invalidate the live iterator.
            for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : target) {
                String first = entry.getFirstKey();
                String second = entry.getSecondKey();
                if (other.contains(first, second)) {
                    target.put(first, second, entry.getValue().plus(other.get(first, second)));
                }
            }
            for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : other) {
                if (!target.contains(entry.getFirstKey(), entry.getSecondKey())) {
                    target.put(entry.getFirstKey(), entry.getSecondKey(), entry.getValue());
                }
            }
        }

        /**
         * Adds {@code other} into {@code target}.
         * Shared keys are summed; keys only in {@code other} are copied by reference.
         *
         * @param target map updated in place
         * @param other map whose entries are added; not modified
         */
        public static void addMatrices(Map<String, SimpleMatrix> target, Map<String, SimpleMatrix> other) {
            for (Map.Entry<String, SimpleMatrix> entry : other.entrySet()) {
                target.merge(entry.getKey(), entry.getValue(), (mine, theirs) -> mine.plus(theirs));
            }
        }

        /** Returns a tree-backed map of zero matrices shaped like each matrix in {@code parameters}. */
        private static TwoDimensionalMap<String, String, SimpleMatrix> initDerivatives(
                TwoDimensionalMap<String, String, SimpleMatrix> parameters) {
            TwoDimensionalMap<String, String, SimpleMatrix> derivatives = TwoDimensionalMap.treeMap();
            for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : parameters) {
                SimpleMatrix parameter = entry.getValue();
                derivatives.put(entry.getFirstKey(), entry.getSecondKey(),
                        new SimpleMatrix(parameter.numRows(), parameter.numCols()));
            }
            return derivatives;
        }

        /** Returns a tree-backed map of zero tensors shaped like each tensor in {@code parameters}. */
        private static TwoDimensionalMap<String, String, SimpleTensor> initTensorDerivatives(
                TwoDimensionalMap<String, String, SimpleTensor> parameters) {
            TwoDimensionalMap<String, String, SimpleTensor> derivatives = TwoDimensionalMap.treeMap();
            for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : parameters) {
                SimpleTensor parameter = entry.getValue();
                derivatives.put(entry.getFirstKey(), entry.getSecondKey(),
                        new SimpleTensor(parameter.numRows(), parameter.numCols(), parameter.numSlices()));
            }
            return derivatives;
        }

        /** Returns a sorted map of zero matrices shaped like each matrix in {@code parameters}. */
        private static Map<String, SimpleMatrix> initDerivatives(Map<String, SimpleMatrix> parameters) {
            Map<String, SimpleMatrix> derivatives = Generics.newTreeMap();
            for (Map.Entry<String, SimpleMatrix> entry : parameters.entrySet()) {
                SimpleMatrix parameter = entry.getValue();
                derivatives.put(entry.getKey(), new SimpleMatrix(parameter.numRows(), parameter.numCols()));
            }
            return derivatives;
        }
    }

    /** Lets {@link MulticoreWrapper} score one chunk of the training batch per task. */
    class ScoringProcessor implements ThreadsafeProcessor<List<Tree>, ModelDerivatives> {
        @Override
        public ModelDerivatives process(List<Tree> trees) {
            return SentimentCostAndGradient.this.scoreDerivatives(trees);
        }

        /**
         * Returns this same instance.
         * {@link #process} builds a fresh {@link ModelDerivatives} per call and reads the shared model.
         */
        @Override
        public ThreadsafeProcessor<List<Tree>, ModelDerivatives> newInstance() {
            return this;
        }
    }

    /**
     * @param sentimentModel model being trained; its parameters are overwritten by {@link #calculate(double[])}
     * @param list training trees; each is deep-copied before being annotated, so they are not modified
     */
    public SentimentCostAndGradient(SentimentModel sentimentModel, List<Tree> list) {
        this.model = sentimentModel;
        this.trainingBatch = list;
    }

    /** Returns the total number of model parameters. */
    @Override
    public int domainDimension() {
        return this.model.totalParamSize();
    }

    /** Sums the prediction errors stored on every non-leaf node of {@code tree} by backpropagation. */
    private static double sumError(Tree tree) {
        if (tree.isLeaf()) {
            return 0.0d;
        }
        if (tree.isPreTerminal()) {
            return RNNCoreAnnotations.getPredictionError(tree);
        }
        double childError = 0.0d;
        for (Tree child : tree.children()) {
            childError += sumError(child);
        }
        return RNNCoreAnnotations.getPredictionError(tree) + childError;
    }

    /** Returns the index of the largest element; the first index wins ties and 0 is returned for empty input. */
    private static int getPredictedClass(SimpleMatrix predictions) {
        int best = 0;
        for (int i = 1; i < predictions.getNumElements(); i++) {
            if (predictions.get(i) > predictions.get(best)) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Forward- and back-propagates deep copies of {@code list} and returns the summed (unscaled,
     * unregularized) derivatives and error.
     *
     * @param list trees to score; they are not modified
     * @return fresh derivatives for these trees
     * @throws ForwardPropagationException if a tree is malformed; the offending copy is logged first
     */
    public ModelDerivatives scoreDerivatives(List<Tree> list) {
        ModelDerivatives derivatives = new ModelDerivatives(this.model);
        List<Tree> propagated = new ArrayList<>(list.size());
        for (Tree original : list) {
            Tree copy = original.deepCopy();
            try {
                forwardPropagateTree(copy);
            } catch (ForwardPropagationException e) {
                log.error("Illegal tree: " + copy);
                throw e;
            }
            propagated.add(copy);
        }
        for (Tree tree : propagated) {
            backpropDerivativesAndError(tree, derivatives.binaryTD, derivatives.binaryCD,
                    derivatives.binaryTensorTD, derivatives.unaryCD, derivatives.wordVectorD);
            derivatives.error += sumError(tree);
        }
        return derivatives;
    }

    /**
     * Computes the regularized, batch-averaged cost and its gradient at {@code dArr}.
     * It first loads {@code dArr} into the model, then stores the results in {@code value} and {@code derivative}.
     */
    @Override
    public void calculate(double[] dArr) {
        this.model.vectorToParams(dArr);

        ModelDerivatives derivatives;
        if (this.model.op.trainOptions.nThreads == 1) {
            derivatives = scoreDerivatives(this.trainingBatch);
        } else {
            MulticoreWrapper<List<Tree>, ModelDerivatives> wrapper =
                    new MulticoreWrapper<>(this.model.op.trainOptions.nThreads, new ScoringProcessor());
            // Partition by the wrapper's own reported thread count.
            for (List<Tree> chunk : CollectionUtils.partitionIntoFolds(this.trainingBatch, wrapper.nThreads())) {
                wrapper.put(chunk);
            }
            wrapper.join();
            derivatives = new ModelDerivatives(this.model);
            while (wrapper.peek()) {
                derivatives.add(wrapper.poll());
            }
        }

        // NOTE(review): an empty training batch makes scale infinite, so value and gradient become NaN.
        double scale = 1.0d / this.trainingBatch.size();
        this.value = derivatives.error * scale;
        this.value += scaleAndRegularize(derivatives.binaryTD, this.model.binaryTransform, scale,
                this.model.op.trainOptions.regTransformMatrix, false);
        this.value += scaleAndRegularize(derivatives.binaryCD, this.model.binaryClassification, scale,
                this.model.op.trainOptions.regClassification, true);
        this.value += scaleAndRegularizeTensor(derivatives.binaryTensorTD, this.model.binaryTensors, scale,
                this.model.op.trainOptions.regTransformTensor);
        this.value += scaleAndRegularize(derivatives.unaryCD, this.model.unaryClassification, scale,
                this.model.op.trainOptions.regClassification, false, true);
        this.value += scaleAndRegularize(derivatives.wordVectorD, this.model.wordVectors, scale,
                this.model.op.trainOptions.regWordVector, true, false);

        // The parameter-group order here must match the model's parameter vector layout.
        this.derivative = NeuralUtils.paramsToVector(dArr.length,
                derivatives.binaryTD.valueIterator(),
                derivatives.binaryCD.valueIterator(),
                SimpleTensor.iteratorSimpleMatrix(derivatives.binaryTensorTD.valueIterator()),
                derivatives.unaryCD.values().iterator(),
                derivatives.wordVectorD.values().iterator());
    }

    /**
     * Replaces each derivative {@code D} with {@code scale * D + regCost * W}, where {@code W} is the
     * matching current parameter matrix.
     *
     * @param derivatives updated in place; must contain every key of {@code currentMatrices}
     * @param currentMatrices current parameters
     * @param scale factor applied to the accumulated derivatives
     * @param regCost L2 regularization strength
     * @param dropBiasColumn if true, the last column of {@code W} (presumably the bias) is zeroed
     *     before regularizing
     * @return the L2 cost {@code regCost / 2 * sum(W .* W)} over all matrices
     */
    private static double scaleAndRegularize(TwoDimensionalMap<String, String, SimpleMatrix> derivatives,
                                             TwoDimensionalMap<String, String, SimpleMatrix> currentMatrices,
                                             double scale, double regCost, boolean dropBiasColumn) {
        double cost = 0.0d;
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : currentMatrices) {
            SimpleMatrix regMatrix = entry.getValue();
            if (dropBiasColumn) {
                regMatrix = new SimpleMatrix(regMatrix);
                regMatrix.insertIntoThis(0, regMatrix.numCols() - 1, new SimpleMatrix(regMatrix.numRows(), 1));
            }
            SimpleMatrix accumulated = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
            derivatives.put(entry.getFirstKey(), entry.getSecondKey(),
                    accumulated.scale(scale).plus(regMatrix.scale(regCost)));
            cost += (regMatrix.elementMult(regMatrix).elementSum() * regCost) / 2.0d;
        }
        return cost;
    }

    /**
     * Map counterpart of {@link #scaleAndRegularize(TwoDimensionalMap, TwoDimensionalMap, double, double, boolean)}.
     *
     * @param activeMatricesOnly if true, a parameter with no derivative entry gets a zero derivative.
     *     It then contributes no regularization, and the gradient vector keeps its full length.
     * @param dropBiasColumn if true, the last column of each parameter is zeroed before regularizing
     * @return the L2 regularization cost
     */
    private static double scaleAndRegularize(Map<String, SimpleMatrix> derivatives,
                                             Map<String, SimpleMatrix> currentMatrices,
                                             double scale, double regCost,
                                             boolean activeMatricesOnly, boolean dropBiasColumn) {
        double cost = 0.0d;
        for (Map.Entry<String, SimpleMatrix> entry : currentMatrices.entrySet()) {
            SimpleMatrix accumulated = derivatives.get(entry.getKey());
            if (activeMatricesOnly && accumulated == null) {
                derivatives.put(entry.getKey(),
                        new SimpleMatrix(entry.getValue().numRows(), entry.getValue().numCols()));
                continue;
            }
            SimpleMatrix regMatrix = entry.getValue();
            if (dropBiasColumn) {
                regMatrix = new SimpleMatrix(regMatrix);
                regMatrix.insertIntoThis(0, regMatrix.numCols() - 1, new SimpleMatrix(regMatrix.numRows(), 1));
            }
            derivatives.put(entry.getKey(), accumulated.scale(scale).plus(regMatrix.scale(regCost)));
            cost += (regMatrix.elementMult(regMatrix).elementSum() * regCost) / 2.0d;
        }
        return cost;
    }

    /**
     * Tensor counterpart of the matrix {@code scaleAndRegularize}. No bias slice is dropped.
     *
     * @return the L2 regularization cost
     */
    private static double scaleAndRegularizeTensor(TwoDimensionalMap<String, String, SimpleTensor> derivatives,
                                                   TwoDimensionalMap<String, String, SimpleTensor> currentTensors,
                                                   double scale, double regCost) {
        double cost = 0.0d;
        for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : currentTensors) {
            SimpleTensor current = entry.getValue();
            SimpleTensor accumulated = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
            derivatives.put(entry.getFirstKey(), entry.getSecondKey(),
                    accumulated.scale(scale).plus(current.scale(regCost)));
            cost += (current.elementMult(current).elementSum() * regCost) / 2.0d;
        }
        return cost;
    }

    /** Backpropagates from the root of an already forward-propagated tree, with a zero incoming delta. */
    private void backpropDerivativesAndError(Tree tree,
                                             TwoDimensionalMap<String, String, SimpleMatrix> binaryTD,
                                             TwoDimensionalMap<String, String, SimpleMatrix> binaryCD,
                                             TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD,
                                             Map<String, SimpleMatrix> unaryCD,
                                             Map<String, SimpleMatrix> wordVectorD) {
        backpropDerivativesAndError(tree, binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD,
                new SimpleMatrix(this.model.op.numHid, 1));
    }

    /**
     * Recursively accumulates derivatives into the given maps and stores each node's prediction error.
     *
     * @param deltaUp error signal arriving from the parent, already multiplied by the tanh derivative
     *     of this node's vector
     */
    private void backpropDerivativesAndError(Tree tree,
                                             TwoDimensionalMap<String, String, SimpleMatrix> binaryTD,
                                             TwoDimensionalMap<String, String, SimpleMatrix> binaryCD,
                                             TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD,
                                             Map<String, SimpleMatrix> unaryCD,
                                             Map<String, SimpleMatrix> wordVectorD,
                                             SimpleMatrix deltaUp) {
        if (tree.isLeaf()) {
            return;
        }

        SimpleMatrix currentVector = RNNCoreAnnotations.getNodeVector(tree);
        String category = this.model.basicCategory(tree.label().value());

        // One-hot gold label; stays all zeros when the node is unlabeled (goldClass < 0).
        SimpleMatrix goldLabel = new SimpleMatrix(this.model.numClasses, 1);
        int goldClass = RNNCoreAnnotations.getGoldClass(tree);
        if (goldClass >= 0) {
            goldLabel.set(goldClass, 1.0d);
        }
        double nodeWeight = this.model.op.trainOptions.getClassWeight(goldClass);

        SimpleMatrix predictions = RNNCoreAnnotations.getPredictions(tree);
        // Softmax + cross-entropy error signal; unlabeled nodes contribute a zero signal.
        SimpleMatrix deltaClass = goldClass >= 0
                ? predictions.minus(goldLabel).scale(nodeWeight)
                : new SimpleMatrix(predictions.numRows(), predictions.numCols());
        SimpleMatrix localCD = deltaClass.mult(NeuralUtils.concatenateWithBias(currentVector).transpose());

        double error = -(NeuralUtils.elementwiseApplyLog(predictions).elementMult(goldLabel).elementSum());
        RNNCoreAnnotations.setPredictionError(tree, error * nodeWeight);

        if (tree.isPreTerminal()) {
            unaryCD.put(category, unaryCD.get(category).plus(localCD));

            String word = this.model.getVocabWord(tree.children()[0].label().value());
            SimpleMatrix deltaFromClass = this.model.getUnaryClassification(category).transpose().mult(deltaClass)
                    .extractMatrix(0, this.model.op.numHid, 0, 1)
                    .elementMult(NeuralUtils.elementwiseApplyTanhDerivative(currentVector));
            SimpleMatrix deltaFull = deltaFromClass.plus(deltaUp);
            wordVectorD.merge(word, deltaFull, (previous, delta) -> previous.plus(delta));
            return;
        }

        String leftCategory = this.model.basicCategory(tree.children()[0].label().value());
        String rightCategory = this.model.basicCategory(tree.children()[1].label().value());
        if (this.model.op.combineClassification) {
            // With combined classification the gradient is accumulated in unaryCD under the "" key.
            unaryCD.put("", unaryCD.get("").plus(localCD));
        } else {
            binaryCD.put(leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).plus(localCD));
        }

        SimpleMatrix deltaFromClass = this.model.getBinaryClassification(leftCategory, rightCategory).transpose()
                .mult(deltaClass)
                .extractMatrix(0, this.model.op.numHid, 0, 1)
                .elementMult(NeuralUtils.elementwiseApplyTanhDerivative(currentVector));
        SimpleMatrix deltaFull = deltaFromClass.plus(deltaUp);

        SimpleMatrix leftVector = RNNCoreAnnotations.getNodeVector(tree.children()[0]);
        SimpleMatrix rightVector = RNNCoreAnnotations.getNodeVector(tree.children()[1]);
        SimpleMatrix localTD = deltaFull.mult(NeuralUtils.concatenateWithBias(leftVector, rightVector).transpose());
        binaryTD.put(leftCategory, rightCategory, binaryTD.get(leftCategory, rightCategory).plus(localTD));

        SimpleMatrix deltaDown;
        if (this.model.op.useTensors) {
            SimpleTensor localTensorTD = getTensorGradient(deltaFull, leftVector, rightVector);
            binaryTensorTD.put(leftCategory, rightCategory,
                    binaryTensorTD.get(leftCategory, rightCategory).plus(localTensorTD));
            deltaDown = computeTensorDeltaDown(deltaFull, leftVector, rightVector,
                    this.model.getBinaryTransform(leftCategory, rightCategory),
                    this.model.getBinaryTensor(leftCategory, rightCategory));
        } else {
            deltaDown = this.model.getBinaryTransform(leftCategory, rightCategory).transpose().mult(deltaFull);
        }

        // Split the downward signal into the left and right child halves.
        SimpleMatrix leftDerivative = NeuralUtils.elementwiseApplyTanhDerivative(leftVector);
        SimpleMatrix rightDerivative = NeuralUtils.elementwiseApplyTanhDerivative(rightVector);
        int hidden = deltaFull.numRows();
        SimpleMatrix leftDeltaDown = deltaDown.extractMatrix(0, hidden, 0, 1);
        SimpleMatrix rightDeltaDown = deltaDown.extractMatrix(hidden, hidden * 2, 0, 1);
        backpropDerivativesAndError(tree.children()[0], binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD,
                leftDerivative.elementMult(leftDeltaDown));
        backpropDerivativesAndError(tree.children()[1], binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD,
                rightDerivative.elementMult(rightDeltaDown));
    }

    /**
     * Computes the error signal flowing to the concatenated children of a tensor node.
     * The result is {@code (W^T delta) + sum_k delta_k (V_k + V_k^T) [left; right]}, with the last
     * row of {@code W^T delta} dropped (presumably the bias entry).
     */
    private static SimpleMatrix computeTensorDeltaDown(SimpleMatrix deltaFull, SimpleMatrix leftVector,
                                                       SimpleMatrix rightVector, SimpleMatrix transform,
                                                       SimpleTensor tensor) {
        SimpleMatrix transformDelta = transform.transpose().mult(deltaFull)
                .extractMatrix(0, deltaFull.numRows() * 2, 0, 1);
        int size = deltaFull.getNumElements();
        SimpleMatrix tensorDelta = new SimpleMatrix(size * 2, 1);
        SimpleMatrix fullVector = NeuralUtils.concatenate(leftVector, rightVector);
        for (int slice = 0; slice < size; slice++) {
            SimpleMatrix symmetricSlice = tensor.getSlice(slice).plus(tensor.getSlice(slice).transpose());
            tensorDelta = tensorDelta.plus(symmetricSlice.mult(fullVector.scale(deltaFull.get(slice))));
        }
        return tensorDelta.plus(transformDelta);
    }

    /** Returns the tensor gradient whose slice {@code k} is {@code delta_k * x x^T}, where {@code x = [left; right]}. */
    private static SimpleTensor getTensorGradient(SimpleMatrix deltaFull, SimpleMatrix leftVector,
                                                  SimpleMatrix rightVector) {
        int size = deltaFull.getNumElements();
        SimpleTensor gradient = new SimpleTensor(size * 2, size * 2, size);
        SimpleMatrix fullVector = NeuralUtils.concatenate(leftVector, rightVector);
        for (int slice = 0; slice < size; slice++) {
            gradient.setSlice(slice, fullVector.scale(deltaFull.get(slice)).mult(fullVector.transpose()));
        }
        return gradient;
    }

    /**
     * Recursively computes node vectors and class predictions bottom-up.
     * The results are stored on each node's {@link CoreLabel}.
     *
     * @throws ForwardPropagationException if a leaf is reached directly, a non-preterminal has one
     *     child, or the tree is not binarized
     * @throws AssertionError if a node label is not a {@link CoreLabel}
     */
    public void forwardPropagateTree(Tree tree) {
        if (tree.isLeaf()) {
            throw new ForwardPropagationException("We should not have reached leaves in forwardPropagate");
        }

        SimpleMatrix classification;
        SimpleMatrix nodeVector;
        if (tree.isPreTerminal()) {
            classification = this.model.getUnaryClassification(tree.label().value());
            nodeVector = NeuralUtils.elementwiseApplyTanh(
                    this.model.getWordVector(tree.children()[0].label().value()));
        } else {
            Tree[] children = tree.children();
            if (children.length == 1) {
                throw new ForwardPropagationException("Non-preterminal nodes of size 1 should have already been collapsed");
            }
            if (children.length != 2) {
                throw new ForwardPropagationException(describeBadBinarization(tree));
            }
            forwardPropagateTree(children[0]);
            forwardPropagateTree(children[1]);

            String leftCategory = children[0].label().value();
            String rightCategory = children[1].label().value();
            SimpleMatrix transform = this.model.getBinaryTransform(leftCategory, rightCategory);
            classification = this.model.getBinaryClassification(leftCategory, rightCategory);

            SimpleMatrix leftVector = RNNCoreAnnotations.getNodeVector(children[0]);
            SimpleMatrix rightVector = RNNCoreAnnotations.getNodeVector(children[1]);
            SimpleMatrix composed = transform.mult(NeuralUtils.concatenateWithBias(leftVector, rightVector));
            if (this.model.op.useTensors) {
                SimpleTensor tensor = this.model.getBinaryTensor(leftCategory, rightCategory);
                composed = composed.plus(tensor.bilinearProducts(NeuralUtils.concatenate(leftVector, rightVector)));
            }
            nodeVector = NeuralUtils.elementwiseApplyTanh(composed);
        }

        SimpleMatrix predictions = NeuralUtils.softmax(classification.mult(NeuralUtils.concatenateWithBias(nodeVector)));
        int predictedClass = getPredictedClass(predictions);
        if (!(tree.label() instanceof CoreLabel)) {
            log.info("SentimentCostAndGradient: warning: No CoreLabels in nodes: " + tree);
            throw new AssertionError("Expected CoreLabels in the nodes");
        }
        CoreLabel label = (CoreLabel) tree.label();
        label.set(RNNCoreAnnotations.Predictions.class, predictions);
        label.set(RNNCoreAnnotations.PredictedClass.class, Integer.valueOf(predictedClass));
        label.set(RNNCoreAnnotations.NodeVector.class, nodeVector);
    }

    /** Builds the error message for a node with more than two children. */
    private static String describeBadBinarization(Tree tree) {
        StringBuilder sb = new StringBuilder();
        sb.append("SentimentCostAndGradient: Tree not correctly binarized:\n   ");
        sb.append(tree);
        sb.append("\nToo many top level constituents present: ");
        sb.append("(" + tree.value());
        for (Tree child : tree.children()) {
            sb.append(" (" + child.value() + " ...)");
        }
        sb.append(")");
        return sb.toString();
    }
}
