package com.saas.voip.service;

import com.saas.admin.dto.VapiCostAnalyticsDTO;
import com.saas.shared.core.TenantContext;
import com.saas.tenant.entity.VapiCall;
import com.saas.tenant.entity.VapiCallCostRecord;
import com.saas.tenant.repository.VapiCallCostRecordRepository;
import com.saas.tenant.repository.VapiCallRepository;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;

/**
 * Vapi.ai cost tracking service.
 *
 * <p>Tracks and calculates costs for Vapi.ai calls:
 * <ul>
 *   <li>Configured Vapi pricing range defaults to $0.13-0.31 per minute
 *       ({@code vapi.cost.per-minute-min} / {@code vapi.cost.per-minute-max}).</li>
 *   <li>The raw cost payload from Vapi (presumably transport, transcriber, model and voice
 *       entries) is stored verbatim as the cost breakdown.</li>
 *   <li>Kept separate from AiApiCostRecord (different cost domain).</li>
 * </ul>
 *
 * <p>Architecture:
 * <ul>
 *   <li>Stores costs in VapiCallCostRecord (tenant database).</li>
 *   <li>Provides per-tenant analytics.</li>
 *   <li>Fed from the end-of-call-report webhook.</li>
 *   <li>Multi-tenant routing via TenantContext.</li>
 * </ul>
 */
@Service
@Slf4j
@RequiredArgsConstructor
public class VapiCostTrackingService {

    /** Scale used for per-minute and per-call monetary ratios. */
    private static final int RATE_SCALE = 6;

    /** Scale used when converting seconds to minutes. */
    private static final int MINUTE_SCALE = 4;

    private static final BigDecimal SECONDS_PER_MINUTE = BigDecimal.valueOf(60);

    private final VapiCallCostRecordRepository costRecordRepository;
    private final VapiCallRepository callRepository;

    // Configured pricing bounds.
    // NOTE(review): neither field is currently read anywhere in this class.
    @Value("${vapi.cost.per-minute-min:0.13}")
    private BigDecimal costPerMinuteMin;

    @Value("${vapi.cost.per-minute-max:0.31}")
    private BigDecimal costPerMinuteMax;

    /**
     * Persists a cost record for a Vapi call using the cost payload received from the webhook.
     *
     * <p>Best-effort: nothing is stored when a record for the same Vapi call id already exists,
     * or when the payload has no {@code "total"} entry. Any exception is logged rather than
     * propagated.
     *
     * @param call     the call the cost belongs to. Its duration (seconds), or its start/end
     *                 timestamps, determine the billed minutes.
     * @param costData raw cost payload. {@code "total"} must hold a numeric value, and the whole
     *                 map is stored as the cost breakdown. {@code null} is treated like a
     *                 missing total.
     */
    @Transactional
    public void trackCallCost(VapiCall call, Map<String, Object> costData) {
        try {
            log.info("💰 Tracking cost for Vapi call: {}", call.getVapiCallId());

            // Skip calls that already have a cost record (e.g. a repeated webhook delivery).
            // NOTE(review): check-then-save is not atomic; concurrent deliveries could still insert
            // twice unless vapiCallId is uniquely constrained in the table -- confirm.
            if (costRecordRepository.findByVapiCallId(call.getVapiCallId()).isPresent()) {
                log.info("ℹ️ Cost record already exists for call: {}", call.getVapiCallId());
                return;
            }

            Object totalCostObj = costData != null ? costData.get("total") : null;
            if (totalCostObj == null) {
                log.warn("⚠️ No total cost in cost data for call: {}", call.getVapiCallId());
                return;
            }

            // Parse via toString() so boxed numbers of any type, and numeric strings, are accepted.
            BigDecimal totalCost = new BigDecimal(totalCostObj.toString());
            BigDecimal durationMinutes = resolveDurationMinutes(call);

            BigDecimal costPerMinute = durationMinutes.signum() > 0
                ? totalCost.divide(durationMinutes, RATE_SCALE, RoundingMode.HALF_UP)
                : BigDecimal.ZERO;

            VapiCallCostRecord costRecord = VapiCallCostRecord.builder()
                .vapiCallId(call.getVapiCallId())
                .provider("VAPI")
                .durationMinutes(durationMinutes)
                .costPerMinute(costPerMinute)
                .totalCost(totalCost)
                .currency("USD")
                .costBreakdown(costData)
                .callStartTime(call.getStartTime())
                .callEndTime(call.getEndTime())
                .assistantId(call.getAssistantId())
                .phoneNumber(call.getPhoneNumber())
                .customerNumber(call.getCustomerNumber())
                .build();

            costRecordRepository.save(costRecord);
            log.info("✅ Cost tracked: ${} for call {} ({} min @ ${}/min)",
                totalCost, call.getVapiCallId(), durationMinutes, costPerMinute);

        } catch (Exception e) {
            // Deliberately best-effort: failures are logged, not rethrown.
            // NOTE(review): if a repository call fails inside this @Transactional method, Spring may
            // already have marked the transaction rollback-only, so the commit can still throw.
            log.error("❌ Error tracking Vapi call cost: {}", e.getMessage(), e);
        }
    }

    /**
     * Returns the call length in minutes (scale 4, HALF_UP).
     *
     * <p>Prefers the explicit duration in seconds and falls back to the start/end timestamps.
     * Returns zero when neither source yields a positive length.
     */
    private static BigDecimal resolveDurationMinutes(VapiCall call) {
        if (call.getDuration() != null && call.getDuration() > 0) {
            return BigDecimal.valueOf(call.getDuration())
                .divide(SECONDS_PER_MINUTE, MINUTE_SCALE, RoundingMode.HALF_UP);
        }
        if (call.getStartTime() != null && call.getEndTime() != null) {
            long durationSeconds = ChronoUnit.SECONDS.between(call.getStartTime(), call.getEndTime());
            // An end time before the start time is inconsistent data; never record negative minutes.
            if (durationSeconds > 0) {
                return BigDecimal.valueOf(durationSeconds)
                    .divide(SECONDS_PER_MINUTE, MINUTE_SCALE, RoundingMode.HALF_UP);
            }
        }
        return BigDecimal.ZERO;
    }

    /**
     * Builds cost analytics for one tenant over an inclusive date range.
     *
     * <p>Sets {@link TenantContext} to {@code tenantId} while querying. The context is cleared
     * afterwards.
     *
     * @param tenantId  tenant whose data is queried
     * @param startDate first day included (from 00:00)
     * @param endDate   last day included (up to, but excluding, 00:00 of the following day)
     * @return totals, averages, a daily series, per-assistant costs and per-status call counts
     * @throws RuntimeException wrapping any failure while loading or aggregating data
     */
    public VapiCostAnalyticsDTO getCostAnalytics(String tenantId, LocalDate startDate, LocalDate endDate) {
        TenantContext.setTenantId(tenantId);
        try {
            log.info("📊 Getting cost analytics for tenant: {} ({} to {})", tenantId, startDate, endDate);

            LocalDateTime startDateTime = startDate.atStartOfDay();
            LocalDateTime endDateTime = endDate.plusDays(1).atStartOfDay();

            // Spring Data derived "...Between" queries are inclusive on both ends (assuming these
            // are derived queries). Drop rows starting exactly at endDateTime, since they belong
            // to the next day.
            List<VapiCallCostRecord> costRecords = costRecordRepository
                .findByCallStartTimeBetween(startDateTime, endDateTime).stream()
                .filter(r -> r.getCallStartTime() == null || r.getCallStartTime().isBefore(endDateTime))
                .collect(Collectors.toList());

            List<VapiCall> calls = callRepository
                .findByStartTimeBetween(startDateTime, endDateTime).stream()
                .filter(c -> c.getStartTime() == null || c.getStartTime().isBefore(endDateTime))
                .collect(Collectors.toList());

            BigDecimal totalCost = sumTotalCost(costRecords);
            Integer totalCalls = calls.size();

            // Cost comes from cost records, but the divisor counts every call in range. Calls
            // without a cost record therefore lower this average.
            BigDecimal averageCostPerCall = totalCalls > 0
                ? totalCost.divide(BigDecimal.valueOf(totalCalls), RATE_SCALE, RoundingMode.HALF_UP)
                : BigDecimal.ZERO;

            Integer totalDurationSeconds = sumDurationSeconds(calls);

            BigDecimal averageDurationSeconds = totalCalls > 0
                ? BigDecimal.valueOf(totalDurationSeconds)
                    .divide(BigDecimal.valueOf(totalCalls), 2, RoundingMode.HALF_UP)
                : BigDecimal.ZERO;

            BigDecimal costPerMinute = averagePositiveCostPerMinute(costRecords);

            List<VapiCostAnalyticsDTO.DailyCostDTO> dailyCosts = calculateDailyCosts(costRecords, calls);
            Map<String, BigDecimal> costsByAssistant = calculateCostsByAssistant(costRecords);

            Map<String, Integer> callsByStatus = calls.stream()
                .collect(Collectors.groupingBy(
                    call -> call.getStatus() != null ? call.getStatus() : "unknown",
                    Collectors.summingInt(c -> 1)
                ));

            return VapiCostAnalyticsDTO.builder()
                .tenantId(tenantId)
                .totalCost(totalCost)
                .totalCalls(totalCalls)
                .averageCostPerCall(averageCostPerCall)
                .totalDurationSeconds(totalDurationSeconds)
                .averageDurationSeconds(averageDurationSeconds)
                .costPerMinute(costPerMinute)
                .startDate(startDate)
                .endDate(endDate)
                .dailyCosts(dailyCosts)
                .costsByAssistant(costsByAssistant)
                .callsByStatus(callsByStatus)
                .build();

        } catch (Exception e) {
            log.error("❌ Error getting cost analytics: {}", e.getMessage(), e);
            throw new RuntimeException("Error getting cost analytics", e);
        } finally {
            // NOTE(review): clears unconditionally; a tenant set by the caller beforehand is not restored.
            TenantContext.clear();
        }
    }

    /**
     * Averages the per-minute rate over records whose rate is present and positive (scale 6, HALF_UP).
     *
     * <p>Returns zero when no record qualifies. The divisor is the number of qualifying records,
     * so null or zero rates do not drag the average down.
     */
    private static BigDecimal averagePositiveCostPerMinute(List<VapiCallCostRecord> costRecords) {
        List<BigDecimal> rates = costRecords.stream()
            .map(VapiCallCostRecord::getCostPerMinute)
            .filter(rate -> rate != null && rate.signum() > 0)
            .collect(Collectors.toList());
        if (rates.isEmpty()) {
            return BigDecimal.ZERO.setScale(RATE_SCALE);
        }
        return rates.stream()
            .reduce(BigDecimal.ZERO, BigDecimal::add)
            .divide(BigDecimal.valueOf(rates.size()), RATE_SCALE, RoundingMode.HALF_UP);
    }

    /** Sums {@code totalCost} over the given records, ignoring missing values. */
    private static BigDecimal sumTotalCost(List<VapiCallCostRecord> records) {
        return records.stream()
            .map(VapiCallCostRecord::getTotalCost)
            .filter(Objects::nonNull)
            .reduce(BigDecimal.ZERO, BigDecimal::add);
    }

    /** Sums call durations (seconds) over the given calls, ignoring missing values. */
    private static int sumDurationSeconds(List<VapiCall> calls) {
        return calls.stream()
            .map(VapiCall::getDuration)
            .filter(Objects::nonNull)
            .mapToInt(Integer::intValue)
            .sum();
    }

    /**
     * Builds a per-day series, in ascending date order.
     *
     * <p>Every day with at least one cost record or call gets an entry. Items are bucketed by
     * call start time, falling back to creation time, so the buckets match the start-time range
     * used to query them.
     */
    private List<VapiCostAnalyticsDTO.DailyCostDTO> calculateDailyCosts(
            List<VapiCallCostRecord> costRecords, List<VapiCall> calls) {

        Map<LocalDate, List<VapiCallCostRecord>> costsByDate = new HashMap<>();
        for (VapiCallCostRecord record : costRecords) {
            LocalDate day = bucketDay(record);
            if (day != null) {
                costsByDate.computeIfAbsent(day, d -> new ArrayList<>()).add(record);
            }
        }

        Map<LocalDate, List<VapiCall>> callsByDate = new HashMap<>();
        for (VapiCall call : calls) {
            LocalDate day = bucketDay(call);
            if (day != null) {
                callsByDate.computeIfAbsent(day, d -> new ArrayList<>()).add(call);
            }
        }

        // Sorted union of both key sets: days with calls but no cost record still get an entry.
        Set<LocalDate> days = new TreeSet<>(costsByDate.keySet());
        days.addAll(callsByDate.keySet());

        List<VapiCostAnalyticsDTO.DailyCostDTO> dailyCosts = new ArrayList<>(days.size());
        for (LocalDate day : days) {
            List<VapiCallCostRecord> dayCosts = costsByDate.getOrDefault(day, List.of());
            List<VapiCall> dayCalls = callsByDate.getOrDefault(day, List.of());

            dailyCosts.add(VapiCostAnalyticsDTO.DailyCostDTO.builder()
                .date(day)
                .cost(sumTotalCost(dayCosts))
                .calls(dayCalls.size())
                .durationSeconds(sumDurationSeconds(dayCalls))
                .build());
        }
        return dailyCosts;
    }

    /** Day a cost record is reported under: call start date, else creation date, else {@code null}. */
    private static LocalDate bucketDay(VapiCallCostRecord record) {
        if (record.getCallStartTime() != null) {
            return record.getCallStartTime().toLocalDate();
        }
        return record.getCreatedAt() != null ? record.getCreatedAt().toLocalDate() : null;
    }

    /** Day a call is reported under: start date, else creation date, else {@code null}. */
    private static LocalDate bucketDay(VapiCall call) {
        if (call.getStartTime() != null) {
            return call.getStartTime().toLocalDate();
        }
        return call.getCreatedAt() != null ? call.getCreatedAt().toLocalDate() : null;
    }

    /**
     * Sums total cost per assistant id.
     *
     * <p>Records without an assistant id are skipped. A missing total counts as zero, so the
     * assistant still appears in the map.
     */
    private Map<String, BigDecimal> calculateCostsByAssistant(List<VapiCallCostRecord> costRecords) {
        Map<String, BigDecimal> costsByAssistant = new HashMap<>();
        for (VapiCallCostRecord record : costRecords) {
            if (record.getAssistantId() == null) {
                continue;
            }
            BigDecimal cost = record.getTotalCost() != null ? record.getTotalCost() : BigDecimal.ZERO;
            costsByAssistant.merge(record.getAssistantId(), cost, BigDecimal::add);
        }
        return costsByAssistant;
    }
}
