/*
 * Decompiled with CFR 0.152.
 */
package tsg.LTSG;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Hashtable;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import settings.Parameters;
import tsg.CFSG;
import tsg.LTSG.LDQQueue;
import tsg.LTSG.LTSG;
import tsg.LTSG.LexDerivationQueue;
import tsg.LTSG.LexicalDerivation;
import tsg.TSNode;
import tsg.parser.Parser;
import util.FileUtil;
import util.Utility;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class LTSG_EMProb
extends CFSG<Double> {
    /** Occurrence count of every lexicalized tree, keyed by the tree's string form; filled by {@link #extractAllLexTrees()}. */
    Hashtable<String, Integer> template_freq = new Hashtable();
    /** Per anchor label, the number of lexicalized trees counted for that anchor; filled by {@link #extractAllLexTrees()}. */
    Hashtable<String, Integer> lexicon_freq = new Hashtable();
    /** Current weight of every template, keyed by the tree's string form; normalized per root by {@code normalizeTemplateProb()}. */
    Hashtable<String, Double> template_prob = new Hashtable();
    /** Sum of the {@link #template_prob} weights of all templates sharing the same {@code TSNode.get_unique_root} key. */
    Hashtable<String, Double> root_prob = new Hashtable();

    /**
     * Populates {@code lexRules} and {@code internalRules} with CFG rules.
     *
     * <p>When {@code Parameters.smoothing} is set, every non-lexical node of every
     * training tree first contributes its CFG rule with weight 1. Afterwards every
     * template in {@link #template_prob} is normalized, relabelled with unique
     * internal labels, and each of its non-terminal nodes contributes its CFG rule
     * with weight {@code templateWeight * Parameters.smoothingFactor}. A summary is
     * appended to the log file.
     */
    public void toPCFG() {
        if (Parameters.smoothing) {
            for (TSNode corpusTree : Parameters.trainingCorpus.treeBank) {
                corpusTree.toNormalForm();
                for (TSNode node : corpusTree.collectNonLexicalNodes()) {
                    Hashtable target = node.isPrelexical() ? this.lexRules : this.internalRules;
                    String cfgRule = node.toCFG(false);
                    Utility.increaseStringDouble(target, cfgRule, 1.0);
                }
            }
        }

        // The running label index is threaded through all templates so that the
        // internal labels stay unique across the whole grammar.
        int nextLabelIndex = 1;
        for (Map.Entry<String, Double> template : this.template_prob.entrySet()) {
            Double weight = template.getValue();
            TSNode fragment = new TSNode(template.getKey(), false);
            fragment.toNormalForm();
            nextLabelIndex = fragment.toUniqueInternalLabels(false, nextLabelIndex, false);
            for (TSNode node : fragment.collectNonTerminalNodes()) {
                String cfgRule = node.toCFG(false);
                Hashtable target = node.isPrelexical() ? this.lexRules : this.internalRules;
                Utility.increaseStringDouble(target, cfgRule, weight * (double)Parameters.smoothingFactor);
            }
        }

        StringBuilder summary = new StringBuilder();
        summary.append("Converted trees to PCFG (smoothing = ").append(Parameters.smoothing).append(")");
        summary.append("\n\t# Internal Rules: ").append(this.internalRules.size());
        summary.append("\n\t# Lex Rules: ").append(this.lexRules.size());
        FileUtil.append(summary.toString(), Parameters.logFile);
    }

    /**
     * Drops every template whose weight in {@link #template_prob} is exactly zero
     * and logs the table size before and after the clean-up.
     */
    public void removeZeroProbTrees() {
        final int sizeBefore = this.template_prob.size();
        for (Iterator<Map.Entry<String, Double>> it = this.template_prob.entrySet().iterator(); it.hasNext();) {
            double weight = it.next().getValue();
            if (weight == 0.0) {
                // Removal must go through the iterator to stay legal during traversal.
                it.remove();
            }
        }
        int sizeAfter = this.template_prob.size();
        FileUtil.append("Removing templates with zero probability\n\t# Templates # before removal: " + sizeBefore + "\n\t# Templates # after removal: " + sizeAfter, Parameters.logFile);
    }

    /**
     * Writes all templates to {@code Parameters.outputPath + "TemplatesFile"}, one
     * per line in the form {@code <weight> <tree>}, then appends a note to the log.
     *
     * <p>The writer is closed in a {@code finally} block so the file handle is
     * released even when writing fails. Because {@link PrintWriter} never throws on
     * write errors, its error flag is checked explicitly and a failure is routed
     * through {@code FileUtil.handleExceptions} like any other exception.
     */
    public void printTemplatesToFile() {
        File templatesFile = new File(String.valueOf(Parameters.outputPath) + "TemplatesFile");
        PrintWriter grammar = null;
        try {
            grammar = new PrintWriter(new BufferedWriter(new FileWriter(templatesFile)));
            for (Map.Entry<String, Double> template : this.template_prob.entrySet()) {
                grammar.write(template.getValue().toString() + " " + template.getKey() + "\n");
            }
            // checkError() flushes and reports I/O failures PrintWriter swallowed.
            if (grammar.checkError()) {
                throw new java.io.IOException("Error while writing templates to " + templatesFile);
            }
        }
        catch (Exception e) {
            FileUtil.handleExceptions(e);
        }
        finally {
            if (grammar != null) {
                grammar.close();
            }
        }
        String log = "Printed templates to file `templatesFile`";
        FileUtil.append(log, Parameters.logFile);
    }

    /**
     * Replaces the current templates with those stored in {@code templateFile}.
     *
     * <p>Each non-blank line must have the form {@code <weight> <tree>}, which is
     * the layout produced by {@link #printTemplatesToFile()}. The weight is parsed
     * with {@link Double#valueOf(String)}, matching the {@code Double.toString}
     * output used when writing, so reading does not depend on the default locale.
     * Blank lines are skipped and the scanner is always closed.
     *
     * <p>{@link #template_prob} and {@link #root_prob} are cleared first;
     * {@code root_prob} accumulates the weights of all templates sharing a root.
     *
     * @param templateFile file to read the templates from
     * @throws NumberFormatException if a line does not start with a valid number
     */
    public void readTreesFromFile(File templateFile) {
        this.template_prob.clear();
        this.root_prob.clear();
        Scanner scan = FileUtil.getScanner(templateFile);
        try {
            while (scan.hasNextLine()) {
                String line = scan.nextLine().trim();
                if (line.length() == 0) {
                    continue;
                }
                // The weight is the first whitespace-delimited token; the rest is the tree.
                int split = 0;
                while (split < line.length() && !Character.isWhitespace(line.charAt(split))) {
                    ++split;
                }
                Double prob = Double.valueOf(line.substring(0, split));
                String tree = line.substring(split).trim();
                this.template_prob.put(tree, prob);
                Utility.increaseStringDouble(this.root_prob, TSNode.get_unique_root(tree), prob);
            }
        }
        finally {
            scan.close();
        }
    }

    /**
     * Rebuilds {@link #root_prob} from scratch: for every template its weight is
     * added to the total stored under the template's unique root key.
     */
    private void extractRootProb() {
        Hashtable<String, Double> rootTotals = new Hashtable<String, Double>();
        this.root_prob = rootTotals;
        for (Map.Entry<String, Double> template : this.template_prob.entrySet()) {
            String rootKey = TSNode.get_unique_root(template.getKey());
            Utility.increaseStringDouble(rootTotals, rootKey, template.getValue());
        }
    }

    /**
     * Sets up the starting point for EM: extracts all lexicalized trees from the
     * training corpus, seeds {@link #template_prob} with each tree's frequency
     * (as a double) and normalizes the weights per root.
     */
    private void initializeUniformTreeProb() {
        this.extractAllLexTrees();
        Hashtable<String, Double> seeded = new Hashtable<String, Double>();
        for (Map.Entry<String, Integer> counted : this.template_freq.entrySet()) {
            seeded.put(counted.getKey(), Double.valueOf(counted.getValue().doubleValue()));
        }
        this.template_prob = seeded;
        this.normalizeTemplateProb();
    }

    /**
     * Divides every template weight by the summed weight of all templates that
     * share its unique root, after refreshing {@link #root_prob}.
     */
    private void normalizeTemplateProb() {
        this.extractRootProb();
        for (Map.Entry<String, Double> template : this.template_prob.entrySet()) {
            String rootKey = TSNode.get_unique_root(template.getKey());
            double rootMass = this.root_prob.get(rootKey);
            // Values are replaced in place; the key set is untouched, so this is safe mid-iteration.
            template.setValue(template.getValue() / rootMass);
        }
    }

    /**
     * Logs the largest value returned by {@code lexDerivations()} over all trees
     * of the training corpus (0 when the corpus is empty).
     */
    public void reportMaxLexicalDerivations() {
        long largest = 0L;
        for (TSNode tree : Parameters.trainingCorpus.treeBank) {
            largest = Math.max(largest, tree.lexDerivations());
        }
        FileUtil.append("Max number of derivation per tree: " + largest, Parameters.logFile);
    }

    /**
     * Runs EM until the likelihood stops improving.
     *
     * <p>After initializing the template weights, {@code emSteps()} is executed at
     * least once and then repeated while the likelihood gain is positive and
     * larger than {@code Parameters.EM_deltaThreshold}. One log line is written
     * per cycle.
     */
    public void EMalgorithm() {
        this.initializeUniformTreeProb();
        // Smallest finite starting value so that the first cycle always counts as an improvement.
        double previous = -Double.MAX_VALUE;
        int cycle = 0;
        boolean improving = true;
        while (improving) {
            double current = this.emSteps();
            double gain = current - previous;
            previous = current;
            cycle++;
            FileUtil.append("EM cycle: " + cycle + "\tActual LikeLihood: " + current + "\tDelta LikeLihood: " + gain, Parameters.logFile);
            improving = gain > 0.0 && gain > Parameters.EM_deltaThreshold;
        }
    }

    /**
     * Decrements {@link #template_freq} by one for every lexicalized tree that
     * {@code LTSG.allLexTreesForEachAnchor} yields for the anchors of {@code tree}.
     *
     * @param tree tree whose lexicalized fragments are discounted
     */
    private void subtractProbability(TSNode tree) {
        List<TSNode> lexicalItems = tree.collectLexicalItems();
        for (ArrayList<TSNode> fragmentsOfAnchor : LTSG.allLexTreesForEachAnchor(tree, false, lexicalItems)) {
            for (TSNode fragment : fragmentsOfAnchor) {
                String key = fragment.toString(false, true);
                Utility.decreaseStringInteger(this.template_freq, key, 1);
            }
        }
    }

    /**
     * Increments {@link #template_freq} by one for every lexicalized tree that
     * {@code LTSG.allLexTreesForEachAnchor} yields for the anchors of {@code tree}.
     *
     * @param tree tree whose lexicalized fragments are counted
     */
    private void addProbability(TSNode tree) {
        List<TSNode> lexicalItems = tree.collectLexicalItems();
        for (ArrayList<TSNode> fragmentsOfAnchor : LTSG.allLexTreesForEachAnchor(tree, false, lexicalItems)) {
            for (TSNode fragment : fragmentsOfAnchor) {
                String key = fragment.toString(false, true);
                Utility.increaseStringInteger(this.template_freq, key, 1);
            }
        }
    }

    /**
     * Performs one EM cycle over the training corpus.
     *
     * <p>For every tree the head annotations are cleared and
     * {@code getNBestHeadAnnotations} accumulates new template weights while
     * returning the tree's contribution to the likelihood. Trees that end up with
     * a wrong head assignment are reported on {@code System.err}. Finally the new
     * weights replace {@link #template_prob} and are normalized per root.
     *
     * @return the sum of the per-tree values returned by {@code getNBestHeadAnnotations}
     */
    private double emSteps() {
        Hashtable<String, Double> reestimated = new Hashtable<String, Double>();
        double total = 0.0;
        for (TSNode tree : Parameters.trainingCorpus.treeBank) {
            tree.removeHeadAnnotations();
            total += this.getNBestHeadAnnotations(tree, reestimated);
            if (tree.hasWrongHeadAssignment()) {
                System.err.println("Wrong Head Assignment: " + tree.toString(true, true));
            }
        }
        this.template_prob = reestimated;
        this.normalizeTemplateProb();
        return total;
    }

    /**
     * Marks the head path selected by one stored derivation of {@code inputTree}.
     *
     * <p>The derivation at position {@code index} of {@code inputTree}'s table
     * names an anchor. Walking from the anchor's parent up to (but excluding)
     * {@code inputTree}, each node on the path gets {@code headMarked = true}. At
     * every step, each daughter of the node just reached that is not on the path
     * is annotated recursively, using the next entry of the derivation's
     * {@code subSiteDerivationsIndexes} as its derivation index.
     *
     * @param inputTree                root of the subtree to annotate
     * @param nBestDerivationsSubTrees derivation tables keyed by node identity
     * @param index                    position of the derivation to apply in {@code inputTree}'s table
     */
    private void assignHeadAnnotation(TSNode inputTree, IdentityHashMap<TSNode, LexicalDerivation[]> nBestDerivationsSubTrees, int index) {
        LexicalDerivation chosen = nBestDerivationsSubTrees.get(inputTree)[index];
        int[] siteChoices = chosen.subSiteDerivationsIndexes;
        TSNode pathChild = chosen.anchor;
        TSNode pathNode = pathChild.parent;
        int site = 0;
        while (pathNode != inputTree) {
            pathNode.headMarked = true;
            // Step one level up; pathChild trails pathNode by exactly one level.
            pathNode = pathNode.parent;
            pathChild = pathChild.parent;
            for (TSNode sibling : pathNode.daughters) {
                if (sibling == pathChild) {
                    continue;
                }
                this.assignHeadAnnotation(sibling, nBestDerivationsSubTrees, siteChoices[site]);
                site++;
            }
        }
    }

    /**
     * Builds the n-best derivation table for the subtree rooted at {@code TN}.
     *
     * <p>Every terminal below {@code TN} is tried as anchor. The path from the
     * anchor up to {@code TN} is temporarily head-marked so that
     * {@code TN.lexicalizedTreeCopy()} yields the elementary tree for that anchor,
     * whose string form is looked up in {@code template_prob}. Anchors whose tree
     * is unknown, or for which some off-path daughter has no table in
     * {@code nBestDerivationsSubTrees}, are discarded (their table slot is
     * {@code null}). The surviving anchors are handed to
     * {@link #computeNBestTable}.
     *
     * @param TN                       root of the subtree to analyse
     * @param nBestDerivationsSubTrees tables already computed for other nodes, keyed by node identity
     * @return the value returned by {@code computeNBestTable} ({@code null} when no anchor survives)
     */
    private LexicalDerivation[] getNBestTable(TSNode TN, IdentityHashMap<TSNode, LexicalDerivation[]> nBestDerivationsSubTrees) {
        TSNode[] lexicalsTN = TN.collectTerminals().toArray(new TSNode[0]);
        // Per anchor: log weight of its elementary tree (-1.0 is written for discarded anchors).
        double[] lexTreeWeights = new double[lexicalsTN.length];
        // Per anchor: one derivation table per substitution site (null for discarded anchors).
        LexicalDerivation[][][] lex_SubSite_NBestTable = new LexicalDerivation[lexicalsTN.length][][];
        int nonNullLex = 0;
        int lexicalIndex = -1;
        TSNode[] tSNodeArray = lexicalsTN;
        int n = lexicalsTN.length;
        int n2 = 0;
        while (n2 < n) {
            TSNode anchor = tSNodeArray[n2];
            ++lexicalIndex;
            int substitutionSites = 0;
            TSNode lexiconPath = anchor.parent;
            // Head-mark the path strictly below TN and add up (prole() - 1) per path node.
            // Presumably prole() is the number of daughters, making this the count of off-path daughters — TODO confirm.
            while (lexiconPath != TN) {
                substitutionSites += lexiconPath.prole() - 1;
                lexiconPath.headMarked = true;
                lexiconPath = lexiconPath.parent;
            }
            // TN itself is counted but not head-marked.
            substitutionSites += lexiconPath.prole() - 1;
            String lexicalTree = TN.lexicalizedTreeCopy().toString(false, true);
            Double weight = this.template_prob.get(lexicalTree);
            if (weight == null) {
                // Unknown elementary tree: undo the marking and discard this anchor.
                TN.unmarkHeadPathToAnchor(anchor);
                lexTreeWeights[lexicalIndex] = -1.0;
                lex_SubSite_NBestTable[lexicalIndex] = null;
            } else {
                double logWeight;
                lexTreeWeights[lexicalIndex] = logWeight = Math.log(weight);
                lexiconPath = anchor.parent;
                lexiconPath.headMarked = false;
                TSNode lexiconPathDaughter = anchor;
                lex_SubSite_NBestTable[lexicalIndex] = new LexicalDerivation[substitutionSites][];
                int substitutionIndex = 0;
                boolean nullSubSide = false;
                // Walk up again, clearing the head marks and collecting the table of every off-path daughter.
                // NOTE(review): the first iteration already moves to anchor.parent.parent, so daughters of
                // anchor.parent are never visited and the loop would start above TN if anchor.parent == TN;
                // this looks like it relies on the anchor's parent having a single daughter and lying
                // strictly below TN — confirm against the caller.
                do {
                    lexiconPath = lexiconPath.parent;
                    lexiconPathDaughter = lexiconPathDaughter.parent;
                    lexiconPath.headMarked = false;
                    TSNode[] tSNodeArray2 = lexiconPath.daughters;
                    int n3 = lexiconPath.daughters.length;
                    int n4 = 0;
                    while (n4 < n3) {
                        TSNode D = tSNodeArray2[n4];
                        if (D != lexiconPathDaughter) {
                            lex_SubSite_NBestTable[lexicalIndex][substitutionIndex] = nBestDerivationsSubTrees.get(D);
                            if (lex_SubSite_NBestTable[lexicalIndex][substitutionIndex] == null) {
                                // A substitution site without any derivation invalidates the whole anchor.
                                nullSubSide = true;
                            }
                            ++substitutionIndex;
                        }
                        ++n4;
                    }
                } while (lexiconPath != TN);
                if (nullSubSide) {
                    lexTreeWeights[lexicalIndex] = -1.0;
                    lex_SubSite_NBestTable[lexicalIndex] = null;
                } else {
                    ++nonNullLex;
                }
            }
            ++n2;
        }
        return LTSG_EMProb.computeNBestTable(nBestDerivationsSubTrees, lex_SubSite_NBestTable, nonNullLex, lexicalsTN, lexTreeWeights);
    }

    /**
     * Merges the per-anchor derivation queues into one n-best table.
     *
     * <p>A {@code LexDerivationQueue} is created for every anchor whose
     * substitution-site tables are non-null. The table size is the smaller of
     * {@code Parameters.EM_nBest} and the summed {@code combinations} of all
     * queues; when that minimum is negative, {@code Parameters.EM_nBest} is used
     * instead. The first entry comes from {@code pollFirst()}, the remaining ones
     * from successive {@code addNeighboursAndPoll()} calls.
     *
     * @param nBestDerivationsSubTrees not used by this method
     * @param lex_SubSite_NBestTable   per anchor, the tables of its substitution sites, or {@code null}
     * @param nonNullLex               number of anchors with a non-null entry
     * @param lexicalsTN               the anchors, parallel to the other arrays
     * @param lexTreeWeights           per anchor, the weight passed on to its queue
     * @return the filled table, or {@code null} when {@code nonNullLex} is 0
     */
    private static LexicalDerivation[] computeNBestTable(IdentityHashMap<TSNode, LexicalDerivation[]> nBestDerivationsSubTrees, LexicalDerivation[][][] lex_SubSite_NBestTable, int nonNullLex, TSNode[] lexicalsTN, double[] lexTreeWeights) {
        if (nonNullLex == 0) {
            return null;
        }
        LexDerivationQueue[] queues = new LexDerivationQueue[lexicalsTN.length];
        int combinationTotal = 0;
        for (int a = 0; a < lexicalsTN.length; ++a) {
            if (lex_SubSite_NBestTable[a] == null) {
                continue;
            }
            queues[a] = new LexDerivationQueue(lex_SubSite_NBestTable[a], lexicalsTN[a], a, lexTreeWeights[a]);
            combinationTotal += queues[a].combinations;
        }
        LDQQueue merged = new LDQQueue(queues);
        int tableSize = Math.min(Parameters.EM_nBest, combinationTotal);
        if (tableSize < 0) {
            tableSize = Parameters.EM_nBest;
        }
        LexicalDerivation[] table = new LexicalDerivation[tableSize];
        table[0] = merged.pollFirst();
        for (int rank = 1; rank < tableSize; ++rank) {
            table[rank] = merged.addNeighboursAndPoll();
        }
        return table;
    }

    /**
     * E-step for a single training tree.
     *
     * <p>Derivation tables are built for the nodes returned by
     * {@code getNodesInDepthLevels()}, in that order. Lexical nodes and unique
     * daughters are skipped. Prelexical nodes and nodes without more than one
     * branching get a single derivation, weighted by the log of their own
     * template weight, or no table when the template is unknown. All other nodes
     * are delegated to {@link #getNBestTable}.
     *
     * <p>The derivations stored for {@code inputTree} are then turned into
     * posteriors ({@code exp(logProb) / sum}). For each derivation the tree is
     * head-annotated accordingly and the posterior is added to
     * {@code new_template_prob} for every fragment of that annotation.
     *
     * @param inputTree         training tree to analyse; its head annotations are overwritten
     * @param new_template_prob accumulator receiving the expected fragment counts
     * @return the log of the summed probability of the stored derivations of {@code inputTree}
     * @throws IllegalStateException if no derivation is available for {@code inputTree}
     */
    private double getNBestHeadAnnotations(TSNode inputTree, Hashtable<String, Double> new_template_prob) {
        ArrayList<ArrayList<TSNode>> levelsSubTrees = inputTree.getNodesInDepthLevels();
        IdentityHashMap<TSNode, LexicalDerivation[]> nBestDerivationsSubTrees = new IdentityHashMap<TSNode, LexicalDerivation[]>();
        for (ArrayList<TSNode> level : levelsSubTrees) {
            for (TSNode TN : level) {
                if (TN.isLexical || TN.isUniqueDaughter()) {
                    continue;
                }
                LexicalDerivation[] nBestTable;
                if (TN.isPrelexical() || !TN.hasMoreThanNBranching(1)) {
                    Double weight = this.template_prob.get(TN.toString(false, true));
                    if (weight == null) {
                        continue;
                    }
                    double logWeight = Math.log(weight);
                    nBestTable = new LexicalDerivation[]{new LexicalDerivation(TN.getAnchor(), 0, logWeight, null, null)};
                } else {
                    nBestTable = this.getNBestTable(TN, nBestDerivationsSubTrees);
                }
                nBestDerivationsSubTrees.put(TN, nBestTable);
            }
        }

        LexicalDerivation[] rootTable = nBestDerivationsSubTrees.get(inputTree);
        if (rootTable == null) {
            // Fail with a descriptive message instead of an anonymous NullPointerException.
            throw new IllegalStateException("No derivation found for tree: " + inputTree.toString(true, true));
        }

        double totalProb = 0.0;
        for (int i = 0; i < rootTable.length; ++i) {
            totalProb += Math.exp(rootTable[i].logDerivationProb);
        }

        // Apply the derivations from last to first, so the first one's annotation is what remains on the tree.
        for (int i = rootTable.length - 1; i > -1; --i) {
            inputTree.removeHeadAnnotations();
            this.assignHeadAnnotation(inputTree, nBestDerivationsSubTrees, i);
            List<TSNode> eTrees = inputTree.lexicalizedTreesFromHeadAnnotation();
            // The posterior is the same for every fragment of one derivation.
            double posterior = Math.exp(rootTable[i].logDerivationProb) / totalProb;
            for (TSNode lexTree : eTrees) {
                Utility.increaseStringDouble(new_template_prob, lexTree.toString(false, true), posterior);
            }
        }
        return Math.log(totalProb);
    }

    /**
     * Counts every lexicalized tree obtainable from the training corpus.
     *
     * <p>Head annotations are removed from the corpus first. For each tree and
     * each of its lexical items, every lexicalized tree returned by
     * {@code LTSG.allLexTreesForEachAnchor} increments {@link #template_freq}
     * under the tree's string form and {@link #lexicon_freq} under the anchor's
     * label. The number of distinct trees is logged at the end.
     */
    public void extractAllLexTrees() {
        Parameters.trainingCorpus.removeHeadAnnotations();
        for (TSNode corpusTree : Parameters.trainingCorpus.treeBank) {
            List<TSNode> anchors = corpusTree.collectLexicalItems();
            List<ArrayList<TSNode>> fragmentsPerAnchor = LTSG.allLexTreesForEachAnchor(corpusTree, Parameters.spineConversion, anchors);
            // fragmentsPerAnchor is parallel to anchors: entry k holds the fragments of anchor k.
            int anchorPosition = 0;
            for (ArrayList<TSNode> fragments : fragmentsPerAnchor) {
                String anchorLabel = anchors.get(anchorPosition).label();
                anchorPosition++;
                for (TSNode fragment : fragments) {
                    Utility.increaseStringInteger(this.template_freq, fragment.toString(false, true), 1);
                    Utility.increaseStringInteger(this.lexicon_freq, anchorLabel, 1);
                }
            }
        }
        FileUtil.append("Extracted all possible lexicalized trees\n\t# Trees: " + this.template_freq.size(), Parameters.logFile);
    }

    /**
     * Greedily chooses how many leading entries of each weight list to use, then
     * returns {@code Utility.combinations} of those per-list counts.
     *
     * <p>Every list starts with a count of 1. While the product of the counts is
     * below {@code n}, the list whose next unused entry has the largest weight
     * strictly above 0 has its count increased by one. The search stops early when
     * no list has such an entry. A list counts as exhausted when its next entry is
     * {@code null} or when all of its entries are already in use; the latter case
     * previously caused an {@link ArrayIndexOutOfBoundsException}.
     *
     * @param substitutionSiteWeightLists one weight list per substitution site
     * @param n                           desired minimum number of combinations
     * @return the result of {@code Utility.combinations} for the selected counts
     */
    public static int[][] bestNcombinations(Double[][] substitutionSiteWeightLists, int n) {
        int elements = substitutionSiteWeightLists.length;
        int[] indexes = new int[elements];
        Arrays.fill(indexes, 1);
        int indexMax = 0;
        while (Utility.product(indexes) < n && indexMax != -1) {
            double max = 0.0;
            indexMax = -1;
            for (int i = 0; i < elements; ++i) {
                Double[] weights = substitutionSiteWeightLists[i];
                int next = indexes[i];
                if (next >= weights.length) {
                    // Every entry of this list is already in use.
                    continue;
                }
                Double candidate = weights[next];
                if (candidate != null && candidate.doubleValue() > max) {
                    max = candidate.doubleValue();
                    indexMax = i;
                }
            }
            if (indexMax != -1) {
                indexes[indexMax] = indexes[indexMax] + 1;
            }
        }
        return Utility.combinations(indexes);
    }

    /**
     * Command-line entry point: configures the global {@code Parameters}, trains
     * the grammar with EM, writes the resulting files and starts the parser.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Global configuration; the output path embeds the grammar type set just before it.
        Parameters.setDefaultParam();
        Parameters.smoothing = false;
        Parameters.LTSGtype = "EM";
        Parameters.outputPath = "/home/fsangati/PROJECTS/TSG/RESULTS/LTSG/" + Parameters.LTSGtype + "/";
        Parameters.EM_nBest = 100;
        Parameters.EM_deltaThreshold = 0.1;
        Parameters.parserName = "fedePar";
        Parameters.nBest = 1;
        Parameters.cachingActive = true;

        // Training.
        LTSG_EMProb grammar = new LTSG_EMProb();
        grammar.reportMaxLexicalDerivations();
        grammar.EMalgorithm();
        grammar.removeZeroProbTrees();
        grammar.toPCFG();

        // Output.
        grammar.printTemplatesToFile();
        grammar.printTrainingCorpusToFile();
        grammar.printLexiconAndGrammarFiles();

        // Parsing.
        new Parser(grammar);
    }
}

