package com.saas.voip.handler;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.saas.voip.factory.VoiceAiSessionFactory;
import com.saas.voip.service.OpenAIRealtimeService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * WebSocket handler for Twilio media-stream connections.
 *
 * <p>Each WebSocket session gets its own {@link AiSessionHandler} (created by
 * {@link VoiceAiSessionFactory} when the connection opens). Incoming JSON text frames are
 * dispatched on their {@code event} field: {@code "media"} frames are forwarded to the
 * session's handler, {@code "start"} records stream/call metadata and triggers
 * {@code onClientConnect}, and {@code "stop"} is only logged. All per-session state is
 * removed when the connection closes.
 */
@Component
@Slf4j
@RequiredArgsConstructor
public class TwilioMediaStreamHandler extends TextWebSocketHandler {

    private final VoiceAiSessionFactory voiceAiSessionFactory;
    private final OpenAIRealtimeService openAIRealtimeService;
    private final ObjectMapper objectMapper = new ObjectMapper();
    /** WebSocket session id -> streamSid taken from the "start" event. */
    private final Map<String, String> sessionStreamIds = new ConcurrentHashMap<>();
    /** WebSocket session id -> AI handler created when the connection was established. */
    private final Map<String, AiSessionHandler> sessionHandlers = new ConcurrentHashMap<>();
    /** WebSocket session id -> call metadata ("callSid", "from", "to"); values may be null. */
    private final Map<String, Map<String, String>> sessionCallData = new ConcurrentHashMap<>();

    /**
     * Creates and registers a fresh AI session handler for the newly opened connection.
     *
     * @param session the newly opened WebSocket session
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) {
        log.info("=== TWILIO WEBSOCKET CONNECTED === Session ID: {}", session.getId());
        log.info("Remote address: {}", session.getRemoteAddress());
        log.info("URI: {}", session.getUri());
        
        AiSessionHandler handler = voiceAiSessionFactory.createHandler();
        sessionHandlers.put(session.getId(), handler);
        // Plain HashMap on purpose: the metadata may contain null values, which
        // ConcurrentHashMap would reject.
        sessionCallData.put(session.getId(), new HashMap<>());
        
        log.info("Created AI session handler: {}", handler.getClass().getSimpleName());
    }

    /**
     * Parses a JSON text frame and dispatches it on its {@code event} field.
     *
     * <p>Messages without an {@code event} field are skipped with a warning. Parse and
     * dispatch errors are logged and not rethrown.
     *
     * @param session the WebSocket session the frame arrived on
     * @param message the raw text frame
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) {
        try {
            JsonNode data = objectMapper.readTree(message.getPayload());
            String event = textOrNull(data, "event");
            if (event == null) {
                // Previously an NPE surfaced as a generic processing error.
                log.warn("Ignoring Twilio message without an event field, session: {}", session.getId());
                return;
            }

            switch (event) {
                case "media" -> handleMediaEvent(session, data);
                case "start" -> handleStartEvent(session, data);
                case "stop" -> handleStopEvent(session, data);
                default -> log.debug("Received non-media event: {}", event);
            }
        } catch (Exception e) {
            log.error("Error processing Twilio message", e);
        }
    }

    /**
     * Forwards the whole media message (re-serialized as JSON) to the session's AI handler.
     * Does nothing if no handler is registered for the session.
     */
    private void handleMediaEvent(WebSocketSession session, JsonNode data) {
        try {
            AiSessionHandler handler = sessionHandlers.get(session.getId());
            if (handler != null) {
                handler.onMediaFrame(session, data.toString());
            }
        } catch (Exception e) {
            log.error("Error handling media event", e);
        }
    }

    /**
     * Records stream and call metadata from a "start" message and notifies the session's
     * handler via {@code onClientConnect}.
     *
     * <p>A "start" message without a {@code start.streamSid} is skipped with a warning.
     * {@code callSid}, {@code From} and {@code To} are null when absent or explicitly JSON
     * {@code null}. {@code onClientConnect} is called only when {@code callSid} is present.
     */
    private void handleStartEvent(WebSocketSession session, JsonNode data) {
        try {
            JsonNode startNode = data.get("start");
            String streamSid = textOrNull(startNode, "streamSid");
            if (streamSid == null) {
                // ConcurrentHashMap rejects null values, and nothing below is meaningful
                // without a stream id.
                log.warn("Ignoring start event without streamSid, session: {}", session.getId());
                return;
            }
            // startNode is non-null here because streamSid was found inside it.
            String callSid = textOrNull(startNode, "callSid");
            
            JsonNode customParameters = startNode.get("customParameters");
            String fromNumber = textOrNull(customParameters, "From");
            String toNumber = textOrNull(customParameters, "To");
            
            sessionStreamIds.put(session.getId(), streamSid);
            
            Map<String, String> callData = sessionCallData.get(session.getId());
            if (callData != null) {
                callData.put("callSid", callSid);
                callData.put("from", fromNumber);
                callData.put("to", toNumber);
            }
            
            log.info("📞 Incoming stream started: streamSid={}, callSid={}, from={}, to={}, session={}", 
                streamSid, callSid, fromNumber, toNumber, session.getId());
            
            AiSessionHandler handler = sessionHandlers.get(session.getId());
            if (handler != null) {
                // For backward compatibility with OpenAI service
                if (handler instanceof OpenAiSessionHandler && callSid != null) {
                    openAIRealtimeService.setCallSid(session.getId(), callSid);
                }
                
                // Call onClientConnect for all handlers with streamSid
                if (callSid != null) {
                    log.info("🚀 Calling onClientConnect for handler: {}", handler.getClass().getSimpleName());
                    handler.onClientConnect(session, streamSid, callSid, fromNumber, toNumber);
                } else {
                    log.warn("⚠️ Cannot call onClientConnect: callSid is null");
                }
            } else {
                log.error("❌ No handler found for session: {}", session.getId());
            }
        } catch (Exception e) {
            log.error("Error handling start event", e);
        }
    }

    /**
     * Logs a "stop" message. The message payload is currently unused, and cleanup happens
     * in {@link #afterConnectionClosed}.
     */
    private void handleStopEvent(WebSocketSession session, JsonNode data) {
        log.info("Stream stopped for session: {}", session.getId());
    }

    /**
     * Closes the session's AI handler, removes all per-session state and disconnects the
     * session from the OpenAI realtime service. Errors from the handler and from the
     * service are logged independently, so one failure does not hide the other.
     *
     * @param session the closed WebSocket session
     * @param status  the close status reported by the container
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
        log.info("=== TWILIO WEBSOCKET DISCONNECTED === Session: {}, status: {}", session.getId(), status);
        
        try {
            AiSessionHandler handler = sessionHandlers.remove(session.getId());
            if (handler != null) {
                handler.onClose(session);
            }
        } catch (Exception e) {
            log.error("Error closing AI session handler", e);
        }
        
        sessionStreamIds.remove(session.getId());
        sessionCallData.remove(session.getId());
        try {
            openAIRealtimeService.disconnectFromOpenAI(session);
        } catch (Exception e) {
            log.error("Error disconnecting session {} from OpenAI", session.getId(), e);
        }
    }

    /**
     * Logs a transport-level error. Per-session cleanup is left to
     * {@link #afterConnectionClosed}.
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) {
        log.error("=== WEBSOCKET TRANSPORT ERROR === Session: {}", session.getId(), exception);
    }

    /**
     * Returns the streamSid recorded from the "start" message for a WebSocket session.
     *
     * @param sessionId the WebSocket session id
     * @return the streamSid, or {@code null} if no "start" message has been processed for
     *         that session (or the session has already closed)
     */
    public String getStreamSid(String sessionId) {
        return sessionStreamIds.get(sessionId);
    }

    /**
     * Reads a text field from a JSON object.
     *
     * <p>Returns null when the parent is null or not an object containing the field, or
     * when the field's value is an explicit JSON {@code null}. This avoids both NPEs and
     * the literal string "null" that {@code asText()} yields for a JSON null.
     */
    private static String textOrNull(JsonNode parent, String field) {
        if (parent == null) {
            return null;
        }
        JsonNode value = parent.get(field);
        return value == null || value.isNull() ? null : value.asText();
    }
}
