/*
 * Decompiled with CFR 0.152.
 */
package tsg;

import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Map;
import java.util.Scanner;
import kernels.NodeSetCollector;
import kernels.NodeSetCollectorSimple;
import kernels.NodeSetCollectorStandard;
import settings.Parameters;
import tsg.Label;
import tsg.TSNodeLabel;
import tsg.TSNodeLabelFreqDouble;
import tsg.TSNodeLabelIndex;
import tsg.TSNodeLabelStructure;
import util.ArgumentReader;
import util.FileUtil;
import util.PrintProgress;
import util.Utility;

public class DOP_IO_Log_MT
extends Thread {
    public static int threads = 1;
    public static int minFreqFragment = 0;
    public static int maxDepthFragment = Integer.MAX_VALUE;
    public static int endCycle = 10;
    public static double deltaLogLikelihoodThreshold = 1.0E-5;
    public static int printProgressEvery = 100;
    String workingDir;
    ArrayList<TSNodeLabelIndex> trainingCorpus;
    Hashtable<TSNodeLabel, double[]> fragmentTableLogFreq;
    Hashtable<Label, double[]> rootTableLogFreq;
    PrintProgress printProgress;
    int treebankSize;

    public DOP_IO_Log_MT(File corpusFile, File fragmentFile, String workingDir) throws Exception {
        this.readTreeBank(corpusFile);
        this.readFragmentsFile(fragmentFile);
        this.workingDir = workingDir;
        this.addCFGfragments();
        this.getRootFreq();
    }

    private void readTreeBank(File trainingFile) throws Exception {
        ArrayList<TSNodeLabel> treebank = TSNodeLabel.getTreebank(trainingFile);
        this.trainingCorpus = new ArrayList();
        for (TSNodeLabel t : treebank) {
            this.trainingCorpus.add(new TSNodeLabelIndex(t));
        }
        this.treebankSize = treebank.size();
        Parameters.reportLineFlush((String)("Corpus size: " + this.treebankSize));
    }

    public void readFragmentsFile(File fragmentFile) throws Exception {
        this.fragmentTableLogFreq = new Hashtable();
        Scanner scan = FileUtil.getScanner(fragmentFile);
        int countFragments = 0;
        int discarded = 0;
        while (scan.hasNextLine()) {
            String line = scan.nextLine();
            if (line.equals("")) continue;
            ++countFragments;
            String[] fragmentFreq = line.split("\t");
            String fragmentString = fragmentFreq[0];
            int freq = Integer.parseInt(fragmentFreq[1]);
            if (freq < minFreqFragment) {
                ++discarded;
                continue;
            }
            TSNodeLabel fragment = new TSNodeLabel(fragmentString, false);
            int depth = fragment.maxDepth();
            if (depth > maxDepthFragment) {
                ++discarded;
                continue;
            }
            double logFreq = Math.log(freq);
            this.fragmentTableLogFreq.put(fragment, new double[]{logFreq});
        }
        Parameters.reportLine((String)("Read " + countFragments + " fragments"));
        Parameters.reportLineFlush((String)("Discarded " + discarded + " (freq < " + minFreqFragment + " || depth >" + maxDepthFragment + ")"));
        scan.close();
    }

    public void printFragmentFreq(File outputFile) {
        PrintWriter pw = FileUtil.getPrintWriter(outputFile);
        for (Map.Entry<TSNodeLabel, double[]> e : this.fragmentTableLogFreq.entrySet()) {
            String fragmentString = e.getKey().toString(false, true);
            double freq = Math.exp(e.getValue()[0]);
            if (freq == 0.0) continue;
            pw.println(String.valueOf(fragmentString) + "\t" + freq);
        }
        pw.close();
    }

    /*
     * WARNING - void declaration
     */
    public void addCFGfragments() throws Exception {
        void var2_5;
        Hashtable ruleTable = new Hashtable();
        for (TSNodeLabel tSNodeLabel : this.trainingCorpus) {
            ArrayList<TSNodeLabel> nodes = tSNodeLabel.collectAllNodes();
            for (TSNodeLabel n : nodes) {
                if (n.isLexical) continue;
                String rule = n.cfgRule();
                Utility.increaseInTableInt(ruleTable, rule);
            }
        }
        Parameters.reportLineFlush((String)("Read " + ruleTable.size() + " CFG fragments"));
        boolean bl = false;
        for (Map.Entry e : ruleTable.entrySet()) {
            TSNodeLabel ruleFragment = new TSNodeLabel("( " + (String)e.getKey() + ")", false);
            if (this.fragmentTableLogFreq.containsKey(ruleFragment)) continue;
            double logFreq = Math.log(((int[])e.getValue())[0]);
            this.fragmentTableLogFreq.put(ruleFragment, new double[]{logFreq});
            ++var2_5;
        }
        Parameters.reportLineFlush((String)("Added " + (int)var2_5 + " CFG fragments"));
    }

    public void getRootFreq() {
        this.rootTableLogFreq = new Hashtable();
        for (Map.Entry<TSNodeLabel, double[]> e : this.fragmentTableLogFreq.entrySet()) {
            Label rootLabel = e.getKey().label;
            double logFreq = e.getValue()[0];
            Utility.increaseInTableDoubleLogArray(this.rootTableLogFreq, rootLabel, logFreq);
        }
        Parameters.reportLineFlush((String)("Built root freq. table: " + this.rootTableLogFreq.size() + " entries."));
    }

    @Override
    public void run() {
        try {
            this.runEM();
        }
        catch (InterruptedException e) {
            Parameters.reportError((String)e.getMessage());
        }
    }

    public void runEM() throws InterruptedException {
        int sentencesPerThreads = this.treebankSize / threads;
        int remainingSentences = this.treebankSize % threads;
        if (remainingSentences != 0) {
            ++sentencesPerThreads;
        }
        int cycle = 0;
        double previousLogLikelihood = Double.NEGATIVE_INFINITY;
        File startFile = new File(String.valueOf(this.workingDir) + "kernelsMUB_CFG_freq_EM_cycle_" + cycle + ".txt");
        this.printFragmentFreq(startFile);
        Parameters.reportLineFlush((String)("Written starting frequencies to file: " + startFile));
        do {
            this.printProgress = new PrintProgress("Iterating Training Corpus:", printProgressEvery, 0);
            EMThreadRunner[] bitParThreadArray = new EMThreadRunner[threads];
            int i = 0;
            while (i < threads) {
                int startIndex = sentencesPerThreads * i;
                EMThreadRunner t = null;
                if (i < threads - 1) {
                    int endIndex = sentencesPerThreads * (i + 1);
                    ArrayList<TSNodeLabelIndex> subtreebank = new ArrayList<TSNodeLabelIndex>(this.trainingCorpus.subList(startIndex, endIndex));
                    t = new EMThreadRunner(subtreebank);
                    t.start();
                } else {
                    ArrayList<TSNodeLabelIndex> subtreebank = new ArrayList<TSNodeLabelIndex>(this.trainingCorpus.subList(startIndex, this.treebankSize));
                    t = new EMThreadRunner(subtreebank);
                    t.run();
                }
                bitParThreadArray[i] = t;
                ++i;
            }
            EMThreadRunner[] eMThreadRunnerArray = bitParThreadArray;
            int t = bitParThreadArray.length;
            int n = 0;
            while (n < t) {
                EMThreadRunner t2 = eMThreadRunnerArray[n];
                t2.join();
                ++n;
            }
            double currentLogLikelihood = 0.0;
            this.fragmentTableLogFreq.clear();
            EMThreadRunner[] eMThreadRunnerArray2 = bitParThreadArray;
            int subtreebank = bitParThreadArray.length;
            int n2 = 0;
            while (n2 < subtreebank) {
                EMThreadRunner t3 = eMThreadRunnerArray2[n2];
                currentLogLikelihood += t3.currentLogLikelihood;
                this.addAll(t3.fragmentTableLogFreqThread);
                ++n2;
            }
            this.getRootFreq();
            this.printProgress.end();
            double deltaLogLikelihood = currentLogLikelihood - previousLogLikelihood;
            Parameters.reportLineFlush((String)("EM cyle " + ++cycle + ". Log-Likelihood: " + currentLogLikelihood + " Delta: " + deltaLogLikelihood));
            if (deltaLogLikelihood <= deltaLogLikelihoodThreshold) break;
            previousLogLikelihood = currentLogLikelihood;
            File outputFile = new File(String.valueOf(this.workingDir) + "kernelsMUB_CFG_freq_EM_cycle_" + cycle + ".txt");
            this.printFragmentFreq(outputFile);
            Parameters.reportLineFlush((String)("Written new frequencies to file: " + outputFile));
        } while (cycle != endCycle);
    }

    private void addAll(Hashtable<TSNodeLabel, double[]> fragmentTableLogFreqThread) {
        for (Map.Entry<TSNodeLabel, double[]> e : fragmentTableLogFreqThread.entrySet()) {
            Utility.increaseInTableDoubleLogArray(this.fragmentTableLogFreq, e.getKey(), e.getValue()[0]);
        }
    }

    private double updateNewFragmentTableFreq(TSNodeLabelIndex t, Hashtable<TSNodeLabel, double[]> newFragmentTableFreq) {
        NodeSetCollectorSimple setCollector = new NodeSetCollectorSimple();
        HashMap<BitSet, TSNodeLabelFreqDouble> bitSetFreqTable = new HashMap<BitSet, TSNodeLabelFreqDouble>();
        for (Map.Entry<TSNodeLabel, double[]> e : this.fragmentTableLogFreq.entrySet()) {
            DOP_IO_Log_MT.getCFGSetCoveringFragment(t, e.getKey(), e.getValue()[0], (NodeSetCollector)setCollector, bitSetFreqTable);
        }
        TSNodeLabelStructure tStructure = new TSNodeLabelStructure(t);
        IOChart pc = new IOChart(setCollector, tStructure, bitSetFreqTable);
        pc.buildChart();
        pc.updateNewFragmentLogFreq(newFragmentTableFreq);
        return pc.getInsideLogProb();
    }

    private static void getCFGSetCoveringFragment(TSNodeLabelIndex t, TSNodeLabel fragment, double fragmentFreq, NodeSetCollector setCollector, HashMap<BitSet, TSNodeLabelFreqDouble> bitSetFreqLogTable) {
        BitSet set;
        if (t.isLexical) {
            return;
        }
        if (t.sameLabel(fragment) && DOP_IO_Log_MT.getCFGSetCoveringFragmentNonRecursive(t, fragment, set = new BitSet()) && !set.isEmpty()) {
            setCollector.add(set);
            bitSetFreqLogTable.put(set, new TSNodeLabelFreqDouble(fragment, fragmentFreq));
        }
        TSNodeLabel[] tSNodeLabelArray = t.daughters;
        int n = t.daughters.length;
        int n2 = 0;
        while (n2 < n) {
            TSNodeLabel d = tSNodeLabelArray[n2];
            TSNodeLabelIndex di = (TSNodeLabelIndex)d;
            DOP_IO_Log_MT.getCFGSetCoveringFragment(di, fragment, fragmentFreq, setCollector, bitSetFreqLogTable);
            ++n2;
        }
    }

    private static boolean getCFGSetCoveringFragmentNonRecursive(TSNodeLabelIndex t, TSNodeLabel fragment, BitSet set) {
        if (t.isLexical || fragment.isTerminal()) {
            return true;
        }
        if (!t.sameDaughtersLabel(fragment)) {
            return false;
        }
        int prole = t.prole();
        int i = 0;
        while (i < prole) {
            TSNodeLabel thisDaughter = t.daughters[i];
            TSNodeLabelIndex thisDaughterIndex = (TSNodeLabelIndex)thisDaughter;
            TSNodeLabel otherDaughter = fragment.daughters[i];
            if (!DOP_IO_Log_MT.getCFGSetCoveringFragmentNonRecursive(thisDaughterIndex, otherDaughter, set)) {
                return false;
            }
            ++i;
        }
        set.set(t.index);
        return true;
    }

    public synchronized void printProgressNext() {
        this.printProgress.next(printProgressEvery);
    }

    public static void main(String[] args) throws Exception {
        minFreqFragment = 1;
        maxDepthFragment = Integer.MAX_VALUE;
        deltaLogLikelihoodThreshold = 0.0;
        File corpusFile = ArgumentReader.readFileOption(args[0]);
        File fragmentFile = ArgumentReader.readFileOption(args[1]);
        String workingDir = new File(args[2]) + "/";
        threads = ArgumentReader.readIntOption(args[3]);
        minFreqFragment = ArgumentReader.readIntOption(args[4]);
        maxDepthFragment = ArgumentReader.readIntOption(args[5]);
        endCycle = ArgumentReader.readIntOption(args[6]);
        File workingDirFile = new File(workingDir);
        if (workingDirFile.exists()) {
            workingDirFile = new File(String.valueOf(workingDir) + FileUtil.dateTimeString() + "/");
        }
        workingDirFile.mkdir();
        Parameters.openLogFile((File)new File(String.valueOf(workingDir) + "log.txt"));
        Parameters.reportLine((String)("Working Dir: " + workingDir));
        Parameters.reportLine((String)("Fragment File: " + fragmentFile));
        Parameters.reportLineFlush((String)("Corpus File: " + corpusFile));
        Parameters.reportLine((String)("threads: " + threads));
        Parameters.reportLine((String)("minFreqFragment: " + minFreqFragment));
        Parameters.reportLine((String)("maxDepthFragment: " + maxDepthFragment));
        Parameters.reportLine((String)("endCycle: " + endCycle));
        Parameters.reportLineFlush((String)("deltaLogLikelihoodThreshold: " + deltaLogLikelihoodThreshold));
        new DOP_IO_Log_MT(corpusFile, fragmentFile, workingDir).run();
    }

    protected class EMThreadRunner
    extends Thread {
        ArrayList<TSNodeLabelIndex> subTreebank;
        Hashtable<TSNodeLabel, double[]> fragmentTableLogFreqThread;
        double currentLogLikelihood;

        public EMThreadRunner(ArrayList<TSNodeLabelIndex> subTreebank) {
            this.subTreebank = subTreebank;
            this.fragmentTableLogFreqThread = new Hashtable();
            this.currentLogLikelihood = 0.0;
        }

        @Override
        public void run() {
            int i = 0;
            for (TSNodeLabelIndex t : this.subTreebank) {
                if (++i == printProgressEvery) {
                    DOP_IO_Log_MT.this.printProgressNext();
                    i = 0;
                }
                double logInsideProb = DOP_IO_Log_MT.this.updateNewFragmentTableFreq(t, this.fragmentTableLogFreqThread);
                this.currentLogLikelihood += logInsideProb;
            }
        }
    }

    class IOChart {
        NodeSetCollectorSimple setCollector;
        TSNodeLabelStructure t;
        int totalNodes;
        IOSubNode[] IOSubNodesChart;
        NodeSetCollectorStandard[] nodesCollector;
        HashMap<BitSet, TSNodeLabelFreqDouble> bitSetFreqTable;

        public IOChart(NodeSetCollectorSimple setCollector, TSNodeLabelStructure t, HashMap<BitSet, TSNodeLabelFreqDouble> bitSetFreqTable) {
            this.setCollector = setCollector;
            this.t = t;
            this.totalNodes = t.length;
            this.IOSubNodesChart = new IOSubNode[this.totalNodes];
            this.nodesCollector = new NodeSetCollectorStandard[this.totalNodes];
            this.bitSetFreqTable = bitSetFreqTable;
        }

        public void buildChart() {
            for (BitSet bs : this.setCollector.bitSetSet) {
                int firstIndex = bs.nextSetBit(0);
                if (this.nodesCollector[firstIndex] == null) {
                    this.nodesCollector[firstIndex] = new NodeSetCollectorStandard();
                }
                this.nodesCollector[firstIndex].add(bs);
            }
            this.buildInsideProb();
            this.buildOutsideProb();
        }

        public double getInsideLogProb() {
            return this.IOSubNodesChart[0].insideLogProb;
        }

        private void buildInsideProb() {
            int i = this.totalNodes - 1;
            while (i >= 0) {
                if (!this.t.structure[i].isLexical) {
                    this.buildInsideProb(i);
                }
                --i;
            }
        }

        private void buildInsideProb(int index) {
            NodeSetCollectorStandard setCollector = this.nodesCollector[index];
            TSNodeLabelIndex root = this.t.structure[index];
            double rootLogFreq = DOP_IO_Log_MT.this.rootTableLogFreq.get(root.label)[0];
            IOSubNode IOSubNodeIndex = new IOSubNode();
            for (BitSet initialSubTree : setCollector.bitSetArray) {
                ArrayList<Integer> subSitesIndexes = new ArrayList<Integer>();
                this.collectSubSites(root, initialSubTree, subSitesIndexes);
                double initialSubTreeInsideLogProb = 0.0;
                for (int subSiteIndex : subSitesIndexes) {
                    double subSiteInsideLogProb = this.IOSubNodesChart[subSiteIndex].insideLogProb;
                    initialSubTreeInsideLogProb += subSiteInsideLogProb;
                }
                TSNodeLabelFreqDouble treeDouble = this.bitSetFreqTable.get(initialSubTree);
                double initialSubTreeFreq = treeDouble.freq;
                TSNodeLabel initialFragment = treeDouble.tree;
                IOSubNodeIndex.addDerivation(initialFragment, subSitesIndexes, initialSubTreeInsideLogProb += initialSubTreeFreq - rootLogFreq);
            }
            this.IOSubNodesChart[index] = IOSubNodeIndex;
        }

        private void collectSubSites(TSNodeLabelIndex root, BitSet initialSubTree, ArrayList<Integer> subSitesIndexes) {
            TSNodeLabel[] tSNodeLabelArray = root.daughters;
            int n = root.daughters.length;
            int n2 = 0;
            while (n2 < n) {
                TSNodeLabel d = tSNodeLabelArray[n2];
                if (d.isLexical) {
                    return;
                }
                TSNodeLabelIndex di = (TSNodeLabelIndex)d;
                int index = di.index;
                if (!initialSubTree.get(index)) {
                    subSitesIndexes.add(index);
                } else {
                    this.collectSubSites(di, initialSubTree, subSitesIndexes);
                }
                ++n2;
            }
        }

        private void buildOutsideProb() {
            this.IOSubNodesChart[0].outsideLogProb = 0.0;
            int i = 0;
            while (i < this.totalNodes) {
                if (!this.t.structure[i].isLexical) {
                    this.buildOutsideProb(i);
                }
                ++i;
            }
        }

        private void buildOutsideProb(int index) {
            IOSubNode IOSubNodeIndex = this.IOSubNodesChart[index];
            double ousideLogProb = IOSubNodeIndex.outsideLogProb;
            for (InitFragmentDerivations ifd : IOSubNodeIndex.partialDerivations) {
                double initialFragmInsideLogProb = ifd.initFragmentInsideLogProb;
                for (int subSite : ifd.subSites) {
                    double subSiteInsideLogProb = this.IOSubNodesChart[subSite].insideLogProb;
                    double outsideLogProbToAdd = ousideLogProb + initialFragmInsideLogProb - subSiteInsideLogProb;
                    IOSubNode subSiteDerivation = this.IOSubNodesChart[subSite];
                    subSiteDerivation.addOutisdeProb(outsideLogProbToAdd);
                }
            }
        }

        public void updateNewFragmentLogFreq(Hashtable<TSNodeLabel, double[]> newFragmentTableFreq) {
            double insideLogProbTOP = this.IOSubNodesChart[0].insideLogProb;
            int i = 0;
            while (i < this.totalNodes) {
                if (!this.t.structure[i].isLexical) {
                    IOSubNode IOSubNodeIndex = this.IOSubNodesChart[i];
                    double ousideLogProb = IOSubNodeIndex.outsideLogProb;
                    for (InitFragmentDerivations ifd : IOSubNodeIndex.partialDerivations) {
                        TSNodeLabel initialFragment = ifd.intialFragment;
                        double initialFragmInsideLogProb = ifd.initFragmentInsideLogProb;
                        double newFreqToAdd = ousideLogProb + initialFragmInsideLogProb - insideLogProbTOP;
                        Utility.increaseInTableDoubleLogArray(newFragmentTableFreq, initialFragment, newFreqToAdd);
                    }
                }
                ++i;
            }
        }
    }

    class IOSubNode {
        ArrayList<InitFragmentDerivations> partialDerivations = new ArrayList();
        double insideLogProb = Double.NEGATIVE_INFINITY;
        double outsideLogProb = Double.NEGATIVE_INFINITY;

        public void addDerivation(TSNodeLabel intialFragment, ArrayList<Integer> subSites, double initFragmInsideLogProb) {
            this.partialDerivations.add(new InitFragmentDerivations(intialFragment, subSites, initFragmInsideLogProb));
            this.insideLogProb = Utility.logSum(this.insideLogProb, initFragmInsideLogProb);
        }

        public void addOutisdeProb(double outsideLogProbToAdd) {
            this.outsideLogProb = Utility.logSum(this.outsideLogProb, outsideLogProbToAdd);
        }
    }

    class InitFragmentDerivations {
        TSNodeLabel intialFragment;
        ArrayList<Integer> subSites;
        double initFragmentInsideLogProb;

        public InitFragmentDerivations(TSNodeLabel intialFragment, ArrayList<Integer> subSites, double initFragmentInsideLogProb) {
            this.intialFragment = intialFragment;
            this.subSites = subSites;
            this.initFragmentInsideLogProb = initFragmentInsideLogProb;
        }
    }
}

