/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.database.rdbms;

import com.healthmarketscience.sqlbuilder.CustomSql;
import com.healthmarketscience.sqlbuilder.InsertQuery;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.linqs.psl.database.Partition;
import org.linqs.psl.database.loading.Inserter;
import org.linqs.psl.database.rdbms.PredicateInfo;
import org.linqs.psl.database.rdbms.RDBMSDataStore;
import org.linqs.psl.model.term.UniqueIntID;
import org.linqs.psl.model.term.UniqueStringID;
import org.linqs.psl.util.ListUtils;
import org.linqs.psl.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link Inserter} that writes rows into the RDBMS table described by a {@link PredicateInfo}.
 * Every row is stored with the ID of the target {@link Partition} ({@code partition_id}), a truth
 * value ({@code value}), and one column per predicate argument.
 *
 * <p>While at least {@link #DEFAULT_MULTIROW_COUNT} rows remain, they are written with a multi-row
 * INSERT; leftover rows are written with a single-row INSERT. Statements are accumulated in a JDBC
 * batch that is executed every {@link #DEFAULT_PAGE_SIZE} statements and once more at the end.
 */
public class RDBMSInserter extends Inserter {
    /** Number of statements added to a JDBC batch before the batch is executed. */
    public static final int DEFAULT_PAGE_SIZE = 2500;

    /** Truth value given to every row inserted through {@link #insertAll(List)}. */
    public static final double DEFAULT_EVIDENCE_VALUE = 1.0;

    /** Number of rows written by one multi-row INSERT statement. */
    public static final int DEFAULT_MULTIROW_COUNT = 25;

    private static final Logger log = LoggerFactory.getLogger(RDBMSInserter.class);

    private final RDBMSDataStore dataStore;
    private final PredicateInfo predicateInfo;
    private final Partition partition;
    private final String singleInsertSQL;
    private final String multiInsertSQL;

    /**
     * @param dataStore supplies JDBC connections and the driver used for bulk copies.
     * @param predicateInfo describes the target table and its argument columns.
     * @param partition the partition whose ID is written with every inserted row.
     */
    public RDBMSInserter(RDBMSDataStore dataStore, PredicateInfo predicateInfo, Partition partition) {
        super(predicateInfo.argumentColumns().size());
        this.dataStore = dataStore;
        this.predicateInfo = predicateInfo;
        this.partition = partition;
        this.singleInsertSQL = createSingleInsert();
        this.multiInsertSQL = createMultiInsert();
    }

    /**
     * Builds a parameterized INSERT for one row.
     * Its columns are {@code partition_id}, {@code value}, then each argument column.
     */
    private String createSingleInsert() {
        InsertQuery sqlBuilder = new InsertQuery(predicateInfo.tableName());
        sqlBuilder.addCustomPreparedColumns(new CustomSql("partition_id"));
        sqlBuilder.addCustomPreparedColumns(new CustomSql("value"));
        for (String column : predicateInfo.argumentColumns()) {
            sqlBuilder.addCustomPreparedColumns(new CustomSql(column));
        }
        return ((InsertQuery) sqlBuilder.validate()).toString();
    }

    /**
     * Builds a parameterized INSERT that writes {@link #DEFAULT_MULTIROW_COUNT} rows at once.
     * It uses the same column order as {@link #createSingleInsert()}.
     */
    private String createMultiInsert() {
        List<String> columns = new ArrayList<String>();
        columns.add("partition_id");
        columns.add("value");
        columns.addAll(predicateInfo.argumentColumns());

        String placeholders = StringUtils.repeat("?", ", ", columns.size());

        List<String> multiInsert = new ArrayList<String>();
        multiInsert.add("INSERT INTO " + predicateInfo.tableName());
        multiInsert.add("    (" + ListUtils.join(", ", columns) + ")");
        multiInsert.add("VALUES");
        multiInsert.add("    " + StringUtils.repeat("(" + placeholders + ")", ", ", DEFAULT_MULTIROW_COUNT));
        return ListUtils.join("\n", multiInsert);
    }

    /**
     * Inserts every row with the truth value {@link #DEFAULT_EVIDENCE_VALUE}.
     */
    @Override
    public void insertAll(List<List<Object>> data) {
        insertInternal(Collections.nCopies(data.size(), DEFAULT_EVIDENCE_VALUE), data);
    }

    /**
     * Inserts each row of {@code data} with the truth value at the same index in {@code values}.
     * A null or NaN truth value is stored as SQL NULL.
     */
    @Override
    public void insertAllValues(List<Double> values, List<List<Object>> data) {
        insertInternal(values, data);
    }

    /** Delegates to the data store's driver. */
    @Override
    public boolean supportsBulkCopy() {
        return dataStore.getDriver().supportsBulkCopy();
    }

    /** Delegates to the data store's driver, targeting this inserter's predicate table and partition. */
    @Override
    public void bulkCopy(String path, String delimiter, boolean hasTruth) {
        dataStore.getDriver().bulkCopy(path, delimiter, hasTruth, predicateInfo, partition);
    }

    /**
     * Validates the input and writes every row to the database.
     *
     * @throws IllegalArgumentException if {@code values} and {@code data} differ in size, if the
     *     partition ID is negative, or if a row is null, has the wrong length, or holds an
     *     unsupported argument type.
     * @throws RuntimeException wrapping any {@link SQLException} raised while inserting.
     */
    private void insertInternal(List<Double> values, List<List<Object>> data) {
        // Explicit checks: assertions are normally disabled in production.
        if (values.size() != data.size()) {
            throw new IllegalArgumentException(String.format(
                    "Number of values (%d) does not match number of rows (%d).", values.size(), data.size()));
        }

        int partitionID = partition.getID();
        if (partitionID < 0) {
            throw new IllegalArgumentException("Partition IDs must be non-negative.");
        }

        validateRows(data);

        if (data.isEmpty()) {
            // Nothing to write; avoid acquiring a connection.
            return;
        }

        // Rows [0, multiRowEnd) exactly fill multi-row statements; the remainder go one at a time.
        int multiRowEnd = data.size() - (data.size() % DEFAULT_MULTIROW_COUNT);

        try (Connection connection = dataStore.getConnection();
                PreparedStatement multiInsertStatement = connection.prepareStatement(multiInsertSQL);
                PreparedStatement singleInsertStatement = connection.prepareStatement(singleInsertSQL)) {
            insertRange(multiInsertStatement, DEFAULT_MULTIROW_COUNT, 0, multiRowEnd, partitionID, values, data);
            insertRange(singleInsertStatement, 1, multiRowEnd, data.size(), partitionID, values, data);
        } catch (SQLException ex) {
            log.error(ex.getMessage());
            throw new RuntimeException("Error inserting into RDBMS.", ex);
        }
    }

    /** Ensures every row is non-null and has exactly one entry per argument column. */
    private void validateRows(List<List<Object>> data) {
        int expected = predicateInfo.argumentColumns().size();
        for (int rowIndex = 0; rowIndex < data.size(); rowIndex++) {
            List<Object> row = data.get(rowIndex);
            if (row == null) {
                throw new IllegalArgumentException(String.format(
                        "Data on row %d is null for %s.", rowIndex, partition.getName()));
            }

            if (row.size() != expected) {
                throw new IllegalArgumentException(String.format(
                        "Data on row %d length does not match for %s: Expecting: %d, Got: %d",
                        rowIndex, partition.getName(), expected, row.size()));
            }
        }
    }

    /**
     * Writes rows [startRow, endRow) through {@code statement}.
     * Each statement execution binds {@code rowsPerStatement} rows, and the JDBC batch is
     * executed every {@link #DEFAULT_PAGE_SIZE} statements and once more at the end.
     * The caller guarantees that {@code endRow - startRow} is a multiple of {@code rowsPerStatement}.
     */
    private void insertRange(PreparedStatement statement, int rowsPerStatement, int startRow, int endRow,
            int partitionID, List<Double> values, List<List<Object>> data) throws SQLException {
        int batchSize = 0;
        int rowIndex = startRow;

        while (rowIndex < endRow) {
            int paramIndex = 1;
            for (int i = 0; i < rowsPerStatement; i++) {
                paramIndex = bindRow(statement, paramIndex, partitionID, values.get(rowIndex), data.get(rowIndex));
                rowIndex++;
            }

            statement.addBatch();
            batchSize++;

            if (batchSize >= DEFAULT_PAGE_SIZE) {
                statement.executeBatch();
                statement.clearBatch();
                batchSize = 0;
            }
        }

        if (batchSize > 0) {
            statement.executeBatch();
            statement.clearBatch();
        }
    }

    /**
     * Binds one row's parameters, starting at {@code firstParam}.
     * The row's parameters are the partition ID, the truth value, then each argument.
     *
     * @return the index of the next unbound parameter.
     */
    private int bindRow(PreparedStatement statement, int firstParam, int partitionID, Double value,
            List<Object> row) throws SQLException {
        int paramIndex = firstParam;

        statement.setInt(paramIndex++, partitionID);

        // Missing and NaN truth values are stored as SQL NULL.
        if (value == null || value.isNaN()) {
            statement.setNull(paramIndex++, Types.DOUBLE);
        } else {
            statement.setDouble(paramIndex++, value);
        }

        int argCount = predicateInfo.argumentColumns().size();
        for (int argIndex = 0; argIndex < argCount; argIndex++) {
            setArgument(statement, paramIndex++, row.get(argIndex), argIndex);
        }

        return paramIndex;
    }

    /**
     * Binds a single argument value according to its runtime type.
     * String values are first converted to the predicate's declared argument type.
     */
    private void setArgument(PreparedStatement statement, int paramIndex, Object argValue, int argIndex)
            throws SQLException {
        if (argValue == null) {
            throw new IllegalArgumentException("Null value for argument " + argIndex + ".");
        }

        if (argValue instanceof Integer) {
            statement.setInt(paramIndex, (Integer) argValue);
        } else if (argValue instanceof Long) {
            statement.setLong(paramIndex, (Long) argValue);
        } else if (argValue instanceof Double) {
            double doubleValue = (Double) argValue;
            if (Double.isNaN(doubleValue)) {
                statement.setNull(paramIndex, Types.DOUBLE);
            } else {
                statement.setDouble(paramIndex, doubleValue);
            }
        } else if (argValue instanceof String) {
            statement.setObject(paramIndex, convertString((String) argValue, argIndex));
        } else if (argValue instanceof UniqueIntID) {
            statement.setInt(paramIndex, ((UniqueIntID) argValue).getID());
        } else if (argValue instanceof UniqueStringID) {
            statement.setString(paramIndex, ((UniqueStringID) argValue).getID());
        } else {
            throw new IllegalArgumentException("Unknown data type for :" + argValue);
        }
    }

    /**
     * Parses a string argument into the Java type matching the predicate's declared argument type.
     *
     * @throws NumberFormatException if a numeric type is declared and {@code value} does not parse.
     * @throws IllegalArgumentException if the declared argument type is not handled here.
     */
    private Object convertString(String value, int argumentIndex) {
        switch (predicateInfo.predicate().getArgumentType(argumentIndex)) {
            case Double:
                return Double.valueOf(value);
            case Integer:
            case UniqueIntID:
                return Integer.valueOf(value);
            case String:
            case UniqueStringID:
                return value;
            case Long:
                return Long.valueOf(value);
            default:
                throw new IllegalArgumentException(
                        "Unknown argument type: " + predicateInfo.predicate().getArgumentType(argumentIndex));
        }
    }
}

