package incheon.ags.dss.weight.batch;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.stereotype.Component;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.function.Consumer;

/**
 * Batch job that rebuilds the weight tables {@code icdss.ana_wghtval_mst} (per-parcel pieces)
 * and {@code icdss.ana_wghtval_sm} (per-region, per-year, per-statistics-item totals).
 *
 * <p>The scope codes of each step are processed in chunks. Up to {@link #CONCURRENT_LIMIT}
 * chunks run in parallel on {@code dssBatchExecutor}. After every parallel group the chunk size
 * is re-tuned from the slowest chunk's duration so that a chunk takes roughly
 * {@link #TARGET_EXECUTION_TIME_MS} ("TCP-like" adaptive batching).
 */
@Component
@RequiredArgsConstructor
@Slf4j
public class AnaWeightBatchProcessor {

    private final JdbcTemplate jdbcTemplate;
    private final NamedParameterJdbcTemplate namedJdbcTemplate;
    private final Executor dssBatchExecutor;

    // Adaptive chunk-size control constants.
    /** Chunk size used for the first parallel group of every task. */
    private static final int INITIAL_CHUNK_SIZE = 10;
    /** Lower bound for the chunk size. */
    private static final int MIN_CHUNK_SIZE = 1;
    /** Upper bound for the chunk size (also bounds the size of the IN (...) list). */
    private static final int MAX_CHUNK_SIZE = 2000;
    /** Number of chunks submitted per parallel group. */
    private static final int CONCURRENT_LIMIT = 5;
    /** Target duration of a single chunk: 30 seconds. */
    private static final long TARGET_EXECUTION_TIME_MS = 30_000;

    // =====================================================================
    // SQL
    // =====================================================================

    /**
     * Inserts master rows for the region scopes in {@code :scpCodes} (scp_type 'R').
     *
     * <p>Each row is the intersection of a region, an 'M' scope cell and a cadastral parcel
     * ({@code icext.lsmd_cont_ldreg}). The parcel's land category ({@code ldcg}) is the last
     * character of {@code jibun}, and {@code km_scp_cd} is the 'M' code minus its last two
     * characters. Empty, NULL or tiny (area &lt;= 0.001) pieces are skipped.
     * Parameters: {@code :scpCodes}, {@code :userId}.
     */
    private static final String SQL_INSERT_MST = """
        INSERT INTO icdss.ana_wghtval_mst (
            sca_scp_cd, meter_scp_cd, km_scp_cd, ldcg, geom, area,
            frst_reg_id, frst_reg_dt, last_mdfcn_id, last_mdfcn_dt
        )
        WITH
            TargetScope AS (
                SELECT r.scp_cd, r.geom
                FROM icdss.ana_stats_data_mst r
                WHERE r.scp_type = 'R'
                  AND r.scp_cd IN (:scpCodes)
            ),
            Final_Spatial_Parcels AS (
                SELECT 
                    r.scp_cd                             AS sca_scp_cd,
                    RIGHT(l.jibun, 1)                    AS ldcg,
                    m.scp_cd                             AS meter_scp_cd,
                    LEFT(m.scp_cd, LENGTH(m.scp_cd) - 2) AS km_scp_cd,
                    ST_Intersection(
                        ST_Intersection(r.geom, m.geom),
                        ST_Buffer(ST_MakeValid(ST_SimplifyPreserveTopology(l.geom, 0.1)), 0)
                    ) AS final_geom
                FROM TargetScope r
                JOIN icdss.ana_stats_data_mst m 
                  ON m.scp_type = 'M' AND ST_Intersects(m.geom, r.geom)
                JOIN icext.lsmd_cont_ldreg l 
                  ON ST_Intersects(l.geom, r.geom) AND ST_Intersects(l.geom, m.geom)
            )
        SELECT 
            vp.sca_scp_cd, vp.meter_scp_cd, vp.km_scp_cd, vp.ldcg, 
            vp.final_geom, ST_Area(vp.final_geom),
            :userId, NOW(), :userId, NOW()
        FROM Final_Spatial_Parcels vp
        WHERE vp.final_geom IS NOT NULL 
          AND NOT ST_IsEmpty(vp.final_geom)
          AND ST_Area(vp.final_geom) > 0.001
          AND vp.ldcg IS NOT NULL
    """;

    /**
     * Inserts, for year {@code :year} and the region codes in {@code :scpCodes}, one total per
     * (region, statistics item).
     *
     * <p>The total is the sum over the region's master parcels of {@code value * area}. The value
     * is chosen in order of preference:
     * <ol>
     *   <li>the 'M' cell value of the item's parent item;</li>
     *   <li>the 'K' cell value divided by 100;</li>
     *   <li>0.001.</li>
     * </ol>
     * Only parcels whose land category is eligible for the item's major category
     * ({@code stats_lclsf_nm}) are counted.
     * Parameters: {@code :year}, {@code :scpCodes}, {@code :userId}.
     */
    private static final String SQL_INSERT_SUM = """
        INSERT INTO icdss.ana_wghtval_sm (
            baye, sca_scp_cd, stats_artcl, tot_wgvl,
            frst_reg_id, frst_reg_dt, last_mdfcn_id, last_mdfcn_dt
        )
        WITH
            TargetScope AS ( SELECT scp_cd FROM icdss.ana_stats_data_mst WHERE scp_cd IN (:scpCodes) ),
            Parcels AS (
                SELECT mst.sca_scp_cd, mst.meter_scp_cd, mst.km_scp_cd, mst.ldcg, mst.area
                FROM icdss.ana_wghtval_mst mst JOIN TargetScope t ON mst.sca_scp_cd = t.scp_cd
            ),
            R_Stats_Types AS (
                SELECT r_dtl.scp_cd AS sca_scp_cd, 
                       r_ty.stats_artcl, r_ty.stats_lclsf_nm, r_ty.prnt_artcl
                FROM icdss.ana_stats_data_dtl r_dtl
                JOIN icdss.ana_stats_data_ty r_ty ON r_dtl.stats_artcl = r_ty.stats_artcl
                JOIN TargetScope t ON r_dtl.scp_cd = t.scp_cd
                WHERE r_dtl.baye = :year AND r_dtl.stats_vl > 0
            ),
            Relevant_K_Stats AS (
                SELECT p.km_scp_cd AS scp_cd, 
                       k_ty.prnt_artcl,
                       SUM(k_dtl.stats_vl) AS stats_vl
                FROM (SELECT DISTINCT km_scp_cd FROM Parcels) p
                JOIN icdss.ana_stats_data_dtl k_dtl ON k_dtl.scp_cd = p.km_scp_cd AND k_dtl.baye = :year
                JOIN icdss.ana_stats_data_ty k_ty ON k_ty.stats_artcl = k_dtl.stats_artcl
                WHERE k_ty.prnt_artcl IN (SELECT DISTINCT prnt_artcl FROM R_Stats_Types)
                GROUP BY p.km_scp_cd, k_ty.prnt_artcl
            ),
            Relevant_M_Stats AS (
                SELECT p.meter_scp_cd AS scp_cd, 
                       m_ty.stats_artcl,
                       SUM(m_dtl.stats_vl) AS stats_vl
                FROM (SELECT DISTINCT meter_scp_cd FROM Parcels) p
                JOIN icdss.ana_stats_data_dtl m_dtl ON m_dtl.scp_cd = p.meter_scp_cd AND m_dtl.baye = :year
                JOIN icdss.ana_stats_data_ty m_ty ON m_ty.stats_artcl = m_dtl.stats_artcl
                WHERE m_ty.stats_artcl IN (SELECT DISTINCT prnt_artcl FROM R_Stats_Types)
                GROUP BY p.meter_scp_cd, m_ty.stats_artcl
            ),
            Aggregated_Denominators AS (
                SELECT 
                    p.sca_scp_cd, rst.stats_artcl,
                    SUM(
                        COALESCE(
                            NULLIF(m.stats_vl, 0),       -- 1순위: 100m 격자 값
                            (k.stats_vl / 100.0),        -- 2순위: 1km 격자 값 (면적비 1:100 보정)
                            0.001                        -- 3순위: 데이터 없음 (최소값)
                        ) * p.area
                    ) AS tot_wgvl
                FROM Parcels p
                JOIN R_Stats_Types rst ON p.sca_scp_cd = rst.sca_scp_cd
                LEFT JOIN Relevant_K_Stats k ON k.scp_cd = p.km_scp_cd AND k.prnt_artcl = rst.prnt_artcl
                LEFT JOIN Relevant_M_Stats m ON m.scp_cd = p.meter_scp_cd AND m.stats_artcl = rst.prnt_artcl
                WHERE
                    CASE
                        WHEN rst.stats_lclsf_nm IN ('인구', '가구', '주택') THEN p.ldcg IN ('대', '장', '학', '잡')
                        WHEN rst.stats_lclsf_nm = '산업' THEN
                            CASE WHEN rst.stats_artcl IN ('cp_bnu_001', 'cp2_bnu_01', 'cp2_bnu_02', 'cp2_bnu_03', 'cp_bem_001', 'cp2_bem_01', 'cp2_bem_02', 'cp2_bem_03')
                                THEN p.ldcg IN ('대', '장', '주', '창', '학', '체', '원', '종', '전', '답', '과', '목', '임', '양')
                                ELSE p.ldcg IN ('대', '장', '주', '창', '학', '체', '원', '종')
                            END
                        ELSE p.ldcg IN ('대', '장', '학', '잡')
                    END
                GROUP BY p.sca_scp_cd, rst.stats_artcl
            )
        SELECT 
            :year, ad.sca_scp_cd, ad.stats_artcl, ad.tot_wgvl,
            :userId, NOW(), :userId, NOW()
        FROM Aggregated_Denominators ad
    """;

    /**
     * Runs the full batch. It first rebuilds the master table, then rebuilds the sum table for
     * every year present in {@code icdss.ana_stats_data_dtl}.
     *
     * <p>This is the job's top-level boundary. Any failure is logged and the job stops, and the
     * exception is not propagated to the caller.
     *
     * @param userId value written into the registration/modification user columns
     */
    public void runBatchJob(String userId) {
        log.info("===== [BATCH START] 가중치 프로세서 (TCP-like Adaptive Batch) =====");
        long start = System.currentTimeMillis();

        try {
            log.info("    - [Config] Work Mem 증설 & 인덱스 튜닝");
            // NOTE(review): SET is session-scoped in PostgreSQL. JdbcTemplate borrows a (possibly
            // different) pooled connection for every call, so these settings only reach whichever
            // connection served this statement. That need not be the one used by the worker-thread
            // inserts or by CREATE INDEX, and the setting then stays on that pooled connection.
            // Consider applying it per connection/transaction (e.g. SET LOCAL inside a transaction).
            jdbcTemplate.execute("SET work_mem = '1GB'");
            jdbcTemplate.execute("SET maintenance_work_mem = '1GB'");

            // 1. Master table
            processMasterLogic(userId);

            // 2. Sum table, one pass per year. NULL years are excluded: they cannot be unboxed
            //    into an int and have no meaning as a "baye" parameter.
            List<Integer> years = jdbcTemplate.queryForList(
                "SELECT DISTINCT baye FROM icdss.ana_stats_data_dtl WHERE baye IS NOT NULL ORDER BY baye",
                Integer.class
            );

            for (int year : years) {
                processSumLogic(year, userId);
            }

        } catch (Exception e) {
            log.error("===== [BATCH ERROR] 중단됨 =====", e);
        } finally {
            log.info("===== [BATCH END] 총 소요: {} =====", formatDuration(System.currentTimeMillis() - start));
        }
    }

    /**
     * Rebuilds {@code icdss.ana_wghtval_mst} from scratch.
     *
     * <p>The table is truncated and its GiST index dropped. Rows are inserted chunk by chunk
     * for every 'R' scope code, and the index is then recreated. The index is recreated even
     * when the insert phase fails, so the table is never left without it.
     *
     * @throws RuntimeException if the initial truncate/drop fails or if any chunk fails
     */
    private void processMasterLogic(String userId) {
        log.info(">>> [Step 1] Master 데이터 생성");
        try {
            jdbcTemplate.execute("TRUNCATE TABLE icdss.ana_wghtval_mst");
            jdbcTemplate.execute("DROP INDEX IF EXISTS icdss.idx_ana_wghtval_mst_geom");
        } catch (Exception e) {
            throw new RuntimeException("초기화 실패", e);
        }

        List<String> codes = jdbcTemplate.queryForList(
            "SELECT scp_cd FROM icdss.ana_stats_data_mst WHERE scp_type = 'R'", String.class
        );
        // Shuffle so that spatially heavy regions are spread across chunks instead of clustering.
        Collections.shuffle(codes);

        try {
            processAdaptiveChunks("Mst", codes, chunk -> {
                MapSqlParameterSource params = new MapSqlParameterSource();
                params.addValue("scpCodes", chunk);
                params.addValue("userId", userId);
                namedJdbcTemplate.update(SQL_INSERT_MST, params);
            });
        } finally {
            createMasterGeomIndex();
        }
    }

    /** Recreates the GiST index on the master geometry; failures are logged as warnings only. */
    private void createMasterGeomIndex() {
        log.info("    - [Info] 인덱스 생성 시작...");
        long idxStart = System.currentTimeMillis();
        try {
            jdbcTemplate.execute("CREATE INDEX idx_ana_wghtval_mst_geom ON icdss.ana_wghtval_mst USING GIST (geom)");
            log.info("    - [Finish] 인덱스 생성 완료 (소요: {})", formatDuration(System.currentTimeMillis() - idxStart));
        } catch (Exception e) {
            log.warn("인덱스 생성 경고", e);
        }
    }

    /**
     * Rebuilds the {@code icdss.ana_wghtval_sm} rows of one year.
     *
     * <p>Existing rows of the year are deleted. Totals are then inserted chunk by chunk for every
     * distinct region code present in the master table.
     *
     * @throws RuntimeException if any chunk fails
     */
    private void processSumLogic(int year, String userId) {
        log.info(">>> [Step 2] {}년도 Sum 생성", year);
        jdbcTemplate.update("DELETE FROM icdss.ana_wghtval_sm WHERE baye = ?", year);

        List<String> codes = jdbcTemplate.queryForList(
            "SELECT DISTINCT sca_scp_cd FROM icdss.ana_wghtval_mst", String.class
        );

        processAdaptiveChunks("Sum(" + year + ")", codes, chunk -> {
            MapSqlParameterSource params = new MapSqlParameterSource();
            params.addValue("year", year);
            params.addValue("scpCodes", chunk);
            params.addValue("userId", userId);
            namedJdbcTemplate.update(SQL_INSERT_SUM, params);
        });
    }

    /**
     * Feeds {@code totalItems} to {@code processor} in adaptively sized chunks.
     *
     * <p>Each round submits up to {@link #CONCURRENT_LIMIT} chunks of the current size to
     * {@code dssBatchExecutor} and waits for all of them. It then recomputes the chunk size from
     * the slowest chunk's duration (see {@link #calculateNextChunkSize}).
     *
     * @param taskName   label used in log lines
     * @param totalItems items to process; {@code null} or empty means nothing to do. The list
     *                   must not be modified while this method runs, because chunks are
     *                   {@code subList} views of it.
     * @param processor  work for one chunk; invoked on executor threads
     * @throws IllegalStateException if any chunk fails. The remaining items are not processed,
     *                               and chunks that already ran are not undone by this method.
     */
    private void processAdaptiveChunks(String taskName, List<String> totalItems, Consumer<List<String>> processor) {
        if (totalItems == null || totalItems.isEmpty()) {
            return;
        }

        int totalSize = totalItems.size();
        int currentIdx = 0;
        int currentChunkSize = INITIAL_CHUNK_SIZE;

        log.info("[{}] Adaptive Batch 시작 (총 {}건, 초기 청크 {})", taskName, totalSize, currentChunkSize);

        while (currentIdx < totalSize) {
            List<CompletableFuture<Long>> futures = new ArrayList<>(CONCURRENT_LIMIT);

            // 1. Submit one parallel group. The last chunk of the input may be shorter.
            for (int i = 0; i < CONCURRENT_LIMIT && currentIdx < totalSize; i++) {
                int endIdx = Math.min(currentIdx + currentChunkSize, totalSize);
                List<String> chunk = totalItems.subList(currentIdx, endIdx);
                currentIdx = endIdx;
                futures.add(CompletableFuture.supplyAsync(() -> runTimedChunk(processor, chunk), dssBatchExecutor));
            }

            // 2. Wait for the whole group. The slowest chunk is the "RTT" sample.
            long maxDurationInGroup = awaitSlowestChunk(taskName, futures);

            // 3. Re-tune the chunk size.
            int oldSize = currentChunkSize;
            currentChunkSize = calculateNextChunkSize(currentChunkSize, maxDurationInGroup);

            // 4. Progress log.
            printAdaptiveProgress(taskName, currentIdx, totalSize, maxDurationInGroup, oldSize, currentChunkSize);
        }
    }

    /**
     * Runs one chunk and returns its elapsed time in milliseconds. Elapsed time is measured with
     * the monotonic {@link System#nanoTime()}, so wall-clock adjustments cannot skew it.
     */
    private long runTimedChunk(Consumer<List<String>> processor, List<String> chunk) {
        long tStart = System.nanoTime();
        try {
            processor.accept(chunk);
        } catch (Exception e) {
            log.error("!!! [ERROR] Chunk 실패: {}", e.getMessage());
            throw new RuntimeException("Data inconsistency detected", e);
        }
        return (System.nanoTime() - tStart) / 1_000_000L;
    }

    /**
     * Waits for every future of a group and returns the longest chunk duration.
     *
     * <p>A failed chunk aborts the task. This matters for two reasons: silently continuing
     * would leave the target table incomplete, and the missing duration sample would be
     * mistaken for a very fast group and grow the chunk size.
     *
     * @throws IllegalStateException if any future completed exceptionally or was cancelled
     */
    private long awaitSlowestChunk(String taskName, List<CompletableFuture<Long>> futures) {
        try {
            // allOf completes only after every given future has completed, even if one fails early.
            CompletableFuture.allOf(futures.toArray(new CompletableFuture<?>[0])).join();
        } catch (RuntimeException e) {
            log.error("Future Get Error", e);
            throw new IllegalStateException("[" + taskName + "] chunk processing failed; remaining chunks skipped", e);
        }

        long slowest = 0;
        for (CompletableFuture<Long> future : futures) {
            slowest = Math.max(slowest, future.join()); // already complete; join() does not block
        }
        return slowest;
    }

    /**
     * Computes the next chunk size from the slowest duration of the previous group.
     *
     * <p>With {@code ratio = TARGET_EXECUTION_TIME_MS / executionTimeMs}:
     * <ul>
     *   <li>Slower than the target: shrink proportionally, {@code size * ratio}.</li>
     *   <li>Otherwise: grow proportionally, but by at most {@code size * 1.5 + 1}.</li>
     * </ul>
     * The result is clamped to [{@link #MIN_CHUNK_SIZE}, {@link #MAX_CHUNK_SIZE}].
     */
    private int calculateNextChunkSize(int currentSize, long executionTimeMs) {
        // Guard against division by zero for sub-millisecond groups.
        long safeTime = Math.max(executionTimeMs, 1);

        // Target-to-actual ratio, e.g. 30 s target and 15 s actual gives 2.0.
        double ratio = (double) TARGET_EXECUTION_TIME_MS / safeTime;

        int nextSize;

        if (executionTimeMs > TARGET_EXECUTION_TIME_MS) {
            // Over target: shrink in proportion to the overshoot (ratio < 1).
            nextSize = (int) (currentSize * ratio);
        } else {
            // Under target: grow proportionally, capped at 1.5x to avoid spikes.
            int proposedSize = (int) (currentSize * ratio);
            int growthLimit = (int) (currentSize * 1.5) + 1; // +1 ensures small sizes can still grow

            nextSize = Math.min(proposedSize, growthLimit);
        }

        return Math.max(MIN_CHUNK_SIZE, Math.min(nextSize, MAX_CHUNK_SIZE));
    }

    /** Logs progress of one round: percent done, slowest duration and the chunk-size change. */
    private void printAdaptiveProgress(String taskName, int processed, int total, long duration, int oldSize, int newSize) {
        double percent = (double) processed / total * 100.0;
        String direction = (newSize > oldSize) ? "▲" : (newSize < oldSize) ? "▼" : "-";

        // SLF4J placeholders do no numeric formatting, so pre-format the numbers.
        String percentStr = String.format("%.1f", percent);
        String durationStr = String.format("%,d", duration);

        log.info("[{}] {}% ({}/{}) | 수행: {}ms | Chunk: {} -> {} {} | (목표 {}s)",
                taskName, percentStr, processed, total,
                durationStr,
                oldSize, newSize, direction,
                TARGET_EXECUTION_TIME_MS / 1000
        );
    }

    /** Formats a millisecond span as "{minutes}분 {seconds}초"; minutes are not wrapped into hours. */
    private String formatDuration(long millis) {
        Duration duration = Duration.ofMillis(millis);
        return String.format("%d분 %d초", duration.toMinutes(), duration.toSecondsPart());
    }
}