package com.saas.voip.service.ai;

import com.saas.shared.dto.AiCostCalculationResult;
import com.saas.shared.service.ai.AiCostCalculator;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

/**
 * Google Gemini Cost Calculator
 * 
 * Pricing (as of Dec 2024 - re-check against the source below before relying on these rates):
 * 
 * Gemini 2.0 Flash (Multimodal Live API - Real-time):
 * - Audio input:  $0.30 per 1M tokens  ($0.000000300 per token)
 * - Audio output: $1.20 per 1M tokens  ($0.000001200 per token)
 * - Text input:   $0.075 per 1M tokens ($0.000000075 per token)
 * - Text output:  $0.30 per 1M tokens  ($0.000000300 per token)
 * 
 * Gemini 1.5 Pro:
 * - Input tokens:  $1.25 per 1M tokens ($0.000001250 per token)
 * - Output tokens: $5.00 per 1M tokens ($0.000005000 per token)
 * 
 * Gemini 1.5 Flash:
 * - Input tokens:  $0.075 per 1M tokens ($0.000000075 per token)
 * - Output tokens: $0.30 per 1M tokens  ($0.000000300 per token)
 * 
 * Source: https://ai.google.dev/pricing
 * 
 * Recognised usage-metric keys:
 * - Real-time (chosen when "audio_input_tokens" or "audio_output_tokens" is present):
 *   "audio_input_tokens", "audio_output_tokens", "text_input_tokens", "text_output_tokens"
 * - Text LLM (otherwise): "input_tokens", "output_tokens"
 * - Optional: "model" (defaults to "gemini-2.0-flash"), "session_id", "conversation_id"
 * 
 * Token counts are taken as-is from the usage metrics; this class does not estimate tokens.
 * They may be any java.lang.Number or a numeric String; missing, unparseable or negative
 * values are treated as 0. Each cost component is rounded to 6 decimals (HALF_UP) before summing.
 * 
 * Clean Architecture:
 * - Implements AiCostCalculator interface
 * - Single Responsibility: Calculate Google Gemini API costs
 * - No persistence logic (handled by AiCostTrackingService)
 */
@Service
@Slf4j
public class GeminiCostCalculator implements AiCostCalculator {
    
    // Gemini 2.0 Flash (Multimodal Live API - Real-time) - Default for voice
    private static final BigDecimal GEMINI_2_FLASH_AUDIO_INPUT_PRICE = new BigDecimal("0.000000300");  // $0.30/1M
    private static final BigDecimal GEMINI_2_FLASH_AUDIO_OUTPUT_PRICE = new BigDecimal("0.000001200"); // $1.20/1M
    private static final BigDecimal GEMINI_2_FLASH_TEXT_INPUT_PRICE = new BigDecimal("0.000000075");   // $0.075/1M
    private static final BigDecimal GEMINI_2_FLASH_TEXT_OUTPUT_PRICE = new BigDecimal("0.000000300");  // $0.30/1M
    
    // Gemini 1.5 Pro
    private static final BigDecimal GEMINI_15_PRO_INPUT_PRICE = new BigDecimal("0.000001250");  // $1.25/1M
    private static final BigDecimal GEMINI_15_PRO_OUTPUT_PRICE = new BigDecimal("0.000005000"); // $5.00/1M
    
    // Gemini 1.5 Flash
    private static final BigDecimal GEMINI_15_FLASH_INPUT_PRICE = new BigDecimal("0.000000075");  // $0.075/1M
    private static final BigDecimal GEMINI_15_FLASH_OUTPUT_PRICE = new BigDecimal("0.000000300"); // $0.30/1M
    
    // Shared, immutable text-pricing tiers (avoids allocating a new holder per call).
    private static final ModelPricing GEMINI_2_FLASH_TEXT_PRICING =
            new ModelPricing(GEMINI_2_FLASH_TEXT_INPUT_PRICE, GEMINI_2_FLASH_TEXT_OUTPUT_PRICE);
    private static final ModelPricing GEMINI_15_PRO_PRICING =
            new ModelPricing(GEMINI_15_PRO_INPUT_PRICE, GEMINI_15_PRO_OUTPUT_PRICE);
    private static final ModelPricing GEMINI_15_FLASH_PRICING =
            new ModelPricing(GEMINI_15_FLASH_INPUT_PRICE, GEMINI_15_FLASH_OUTPUT_PRICE);
    
    /** Model assumed when the usage metrics do not name one (absent key or null value). */
    private static final String DEFAULT_MODEL = "gemini-2.0-flash";
    
    /** Decimal places each cost component is rounded to (HALF_UP) before the components are summed. */
    private static final int COST_SCALE = 6;
    
    private static final String PROVIDER_NAME = "GEMINI";
    private static final String COST_TYPE_REALTIME = "REALTIME_API";
    private static final String COST_TYPE_LLM = "LLM";
    private static final String CURRENCY = "USD";
    
    /**
     * Calculates the USD cost of a single Gemini usage record.
     *
     * @param usageMetrics usage counters and optional metadata (see the class Javadoc for the
     *                     recognised keys); may be null or empty
     * @return the cost result. A zero-cost LLM result is returned when the metrics are null/empty
     *         or when the calculation throws.
     */
    @Override
    public AiCostCalculationResult calculateCost(Map<String, Object> usageMetrics) {
        log.debug("💰 [Gemini] Calculating cost from usage metrics: {}", usageMetrics);
        
        if (usageMetrics == null || usageMetrics.isEmpty()) {
            log.warn("⚠️ [Gemini] Empty usage metrics provided");
            return createZeroCostResult();
        }
        
        try {
            String model = resolveModel(usageMetrics);
            // Presence of any audio counter selects the Multimodal Live (real-time) pricing path.
            boolean isRealtimeApi = usageMetrics.containsKey("audio_input_tokens") || 
                                   usageMetrics.containsKey("audio_output_tokens");
            
            Map<String, Object> detailedUsage = new HashMap<>();
            BigDecimal totalCost = isRealtimeApi
                    ? calculateRealtimeCost(usageMetrics, detailedUsage)
                    : calculateTextCost(usageMetrics, model, detailedUsage);
            
            log.info("💰 [Gemini] Total cost - Model: {}, Cost: ${}", model, totalCost);
            
            detailedUsage.put("model", model);
            
            // Copy only the correlation identifiers that were actually supplied.
            Map<String, Object> metadata = new HashMap<>();
            if (usageMetrics.containsKey("session_id")) {
                metadata.put("session_id", usageMetrics.get("session_id"));
            }
            if (usageMetrics.containsKey("conversation_id")) {
                metadata.put("conversation_id", usageMetrics.get("conversation_id"));
            }
            
            return AiCostCalculationResult.builder()
                    .aiProvider(PROVIDER_NAME)
                    .costType(isRealtimeApi ? COST_TYPE_REALTIME : COST_TYPE_LLM)
                    .cost(totalCost)
                    .currency(CURRENCY)
                    .usageDetails(detailedUsage)
                    .metadata(metadata.isEmpty() ? null : metadata)
                    .build();
            
        } catch (Exception e) {
            // Deliberately best-effort: any failure is logged and reported as zero cost
            // instead of being propagated to the caller.
            log.error("❌ [Gemini] Error calculating cost", e);
            return createZeroCostResult();
        }
    }
    
    /**
     * Returns the model name from the metrics. Falls back to DEFAULT_MODEL when the key is
     * absent or mapped to null. Non-String values are converted with toString() rather than
     * causing a ClassCastException.
     */
    private String resolveModel(Map<String, Object> usageMetrics) {
        Object rawModel = usageMetrics.get("model");
        return rawModel != null ? rawModel.toString() : DEFAULT_MODEL;
    }
    
    /**
     * Calculate cost for Gemini Multimodal Live API (Real-time voice).
     * Always applies the Gemini 2.0 Flash real-time rates; the requested model is not consulted.
     *
     * @param usageMetrics  source of the four audio/text token counters
     * @param detailedUsage output map that receives per-component token counts and costs
     * @return sum of the four rounded component costs
     */
    private BigDecimal calculateRealtimeCost(Map<String, Object> usageMetrics, Map<String, Object> detailedUsage) {
        long audioInputTokens = extractLong(usageMetrics, "audio_input_tokens");
        long audioOutputTokens = extractLong(usageMetrics, "audio_output_tokens");
        long textInputTokens = extractLong(usageMetrics, "text_input_tokens");
        long textOutputTokens = extractLong(usageMetrics, "text_output_tokens");
        
        log.info("💰 [Gemini] Real-time usage - Audio In: {}, Audio Out: {}, Text In: {}, Text Out: {}", 
                audioInputTokens, audioOutputTokens, textInputTokens, textOutputTokens);
        
        BigDecimal audioInputCost = priceTokens(audioInputTokens, GEMINI_2_FLASH_AUDIO_INPUT_PRICE);
        BigDecimal audioOutputCost = priceTokens(audioOutputTokens, GEMINI_2_FLASH_AUDIO_OUTPUT_PRICE);
        BigDecimal textInputCost = priceTokens(textInputTokens, GEMINI_2_FLASH_TEXT_INPUT_PRICE);
        BigDecimal textOutputCost = priceTokens(textOutputTokens, GEMINI_2_FLASH_TEXT_OUTPUT_PRICE);
        
        BigDecimal totalCost = audioInputCost.add(audioOutputCost).add(textInputCost).add(textOutputCost);
        
        log.info("💰 [Gemini] Cost breakdown - Audio In: ${}, Audio Out: ${}, Text In: ${}, Text Out: ${}, Total: ${}", 
                audioInputCost, audioOutputCost, textInputCost, textOutputCost, totalCost);
        
        detailedUsage.put("audio_input_tokens", audioInputTokens);
        detailedUsage.put("audio_output_tokens", audioOutputTokens);
        detailedUsage.put("text_input_tokens", textInputTokens);
        detailedUsage.put("text_output_tokens", textOutputTokens);
        detailedUsage.put("audio_input_cost", audioInputCost);
        detailedUsage.put("audio_output_cost", audioOutputCost);
        detailedUsage.put("text_input_cost", textInputCost);
        detailedUsage.put("text_output_cost", textOutputCost);
        
        return totalCost;
    }
    
    /**
     * Calculate cost for standard Gemini text LLM API.
     *
     * @param usageMetrics  source of "input_tokens" / "output_tokens"
     * @param model         model name used to select the pricing tier
     * @param detailedUsage output map that receives token counts and per-direction costs
     * @return sum of the rounded input and output costs
     */
    private BigDecimal calculateTextCost(Map<String, Object> usageMetrics, String model, Map<String, Object> detailedUsage) {
        long inputTokens = extractLong(usageMetrics, "input_tokens");
        long outputTokens = extractLong(usageMetrics, "output_tokens");
        
        log.info("💰 [Gemini] Text usage - Model: {}, Input: {}, Output: {}", 
                model, inputTokens, outputTokens);
        
        ModelPricing pricing = getPricingForModel(model);
        
        BigDecimal inputCost = priceTokens(inputTokens, pricing.inputPrice);
        BigDecimal outputCost = priceTokens(outputTokens, pricing.outputPrice);
        BigDecimal totalCost = inputCost.add(outputCost);
        
        log.info("💰 [Gemini] Cost breakdown - Input: ${}, Output: ${}, Total: ${}", 
                inputCost, outputCost, totalCost);
        
        detailedUsage.put("input_tokens", inputTokens);
        detailedUsage.put("output_tokens", outputTokens);
        detailedUsage.put("input_cost", inputCost);
        detailedUsage.put("output_cost", outputCost);
        
        return totalCost;
    }
    
    /**
     * Multiplies a token count by a per-token price and rounds to COST_SCALE decimals (HALF_UP).
     */
    private static BigDecimal priceTokens(long tokens, BigDecimal pricePerToken) {
        return BigDecimal.valueOf(tokens)
                .multiply(pricePerToken)
                .setScale(COST_SCALE, RoundingMode.HALF_UP);
    }
    
    @Override
    public String getProviderName() {
        return PROVIDER_NAME;
    }
    
    @Override
    public String[] getSupportedCostTypes() {
        return new String[]{COST_TYPE_REALTIME, COST_TYPE_LLM};
    }
    
    /**
     * Get text pricing for a specific Gemini model, matched by case-insensitive substring.
     * Null falls back to Gemini 2.0 Flash pricing silently; unrecognised names fall back to the
     * same pricing with a warning.
     *
     * NOTE(review): the "gemini-2" match is broad and also catches later 2.x model names,
     * pricing them at 2.0 Flash rates - confirm this is intended.
     */
    private ModelPricing getPricingForModel(String model) {
        if (model == null) {
            return GEMINI_2_FLASH_TEXT_PRICING;
        }
        
        // Locale.ROOT keeps matching stable regardless of the JVM default locale (e.g. Turkish dotless i).
        String normalizedModel = model.toLowerCase(Locale.ROOT);
        
        if (normalizedModel.contains("2.0-flash") || normalizedModel.contains("gemini-2")) {
            return GEMINI_2_FLASH_TEXT_PRICING;
        }
        if (normalizedModel.contains("1.5-pro")) {
            return GEMINI_15_PRO_PRICING;
        }
        if (normalizedModel.contains("1.5-flash")) {
            return GEMINI_15_FLASH_PRICING;
        }
        
        log.warn("⚠️ [Gemini] Unknown model '{}', using Gemini 2.0 Flash pricing", model);
        return GEMINI_2_FLASH_TEXT_PRICING;
    }
    
    /**
     * Extract a non-negative token count from usage metrics.
     *
     * Accepts any java.lang.Number (truncated via longValue(), so JSON-parsed Doubles work) or a
     * numeric String (surrounding whitespace ignored). Missing keys yield 0. Unparseable,
     * unsupported-type or negative values are logged and yield 0, so a bad counter never
     * produces a negative cost.
     */
    private long extractLong(Map<String, Object> metrics, String key) {
        Object value = metrics.get(key);
        if (value == null) {
            return 0L;
        }
        
        long parsed;
        if (value instanceof Number) {
            parsed = ((Number) value).longValue();
        } else if (value instanceof String) {
            try {
                parsed = Long.parseLong(((String) value).trim());
            } catch (NumberFormatException e) {
                log.warn("⚠️ [Gemini] Cannot parse {} as long: {}", key, value);
                return 0L;
            }
        } else {
            log.warn("⚠️ [Gemini] Unexpected type for {}: {}", key, value.getClass());
            return 0L;
        }
        
        if (parsed < 0) {
            log.warn("⚠️ [Gemini] Negative value for {}: {}, treating as 0", key, value);
            return 0L;
        }
        return parsed;
    }
    
    /**
     * Create a zero-cost result for error cases (always typed as LLM, with empty usage details).
     */
    private AiCostCalculationResult createZeroCostResult() {
        return AiCostCalculationResult.builder()
                .aiProvider(PROVIDER_NAME)
                .costType(COST_TYPE_LLM)
                .cost(BigDecimal.ZERO)
                .currency(CURRENCY)
                .usageDetails(new HashMap<>())
                .build();
    }
    
    /**
     * Immutable holder for a model's per-token input and output prices.
     */
    private static final class ModelPricing {
        private final BigDecimal inputPrice;
        private final BigDecimal outputPrice;
        
        ModelPricing(BigDecimal inputPrice, BigDecimal outputPrice) {
            this.inputPrice = inputPrice;
            this.outputPrice = outputPrice;
        }
    }
}
