package org.springframework.security.config.annotation.web.socket;

import io.micrometer.observation.ObservationRegistry;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import org.springframework.beans.factory.SmartInitializingSingleton;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Import;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.messaging.Message;
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.security.authorization.AuthorizationManager;
import org.springframework.security.authorization.ObservationAuthorizationManager;
import org.springframework.security.authorization.SpringAuthorizationEventPublisher;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.messaging.access.intercept.AuthorizationChannelInterceptor;
import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager;
import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver;
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
import org.springframework.security.messaging.web.csrf.XorCsrfChannelInterceptor;
import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor;
import org.springframework.util.Assert;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService;

/**
 * Security configuration for a WebSocket message broker.
 *
 * <p>On the client inbound channel it registers three interceptors, in this order:
 * a {@link SecurityContextChannelInterceptor}, a CSRF {@link ChannelInterceptor}, and an
 * {@link AuthorizationChannelInterceptor} that applies the configured
 * {@link AuthorizationManager}. After all singletons are created, it also prepends a
 * {@link CsrfTokenHandshakeInterceptor} to every handshake handler found in the
 * {@code stompWebSocketHandlerMapping} bean, if that bean exists.
 */
// Ordered.HIGHEST_PRECEDENCE + 100 == -2147483548.
// NOTE(review): presumably chosen so this configurer runs before user configurers; confirm.
@Order(Ordered.HIGHEST_PRECEDENCE + 100)
@Import({MessageMatcherAuthorizationManagerConfiguration.class})
final class WebSocketMessageBrokerSecurityConfiguration implements WebSocketMessageBrokerConfigurer, SmartInitializingSingleton {

    /** Name of the handler-mapping bean whose handshake handlers receive CSRF handshake support. */
    private static final String SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME = "stompWebSocketHandlerMapping";

    /** Name of an optional {@link ChannelInterceptor} bean that replaces the default CSRF interceptor. */
    private static final String CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME = "csrfChannelInterceptor";

    /** Default rule used when no {@code AuthorizationManager<Message<?>>} is injected: every message must be authenticated. */
    private static final AuthorizationManager<Message<?>> ANY_MESSAGE_AUTHENTICATED = MessageMatcherDelegatingAuthorizationManager.builder()
        .anyMessage()
        .authenticated()
        .build();

    private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor();

    private final ApplicationContext context;

    private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder.getContextHolderStrategy();

    private ChannelInterceptor csrfChannelInterceptor = new XorCsrfChannelInterceptor();

    private AuthorizationManager<Message<?>> authorizationManager = ANY_MESSAGE_AUTHENTICATED;

    private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

    WebSocketMessageBrokerSecurityConfiguration(ApplicationContext applicationContext) {
        this.context = applicationContext;
    }

    /**
     * Adds an {@link AuthenticationPrincipalArgumentResolver} that uses the configured
     * {@link SecurityContextHolderStrategy}.
     *
     * @param argumentResolvers the resolver list to append to
     */
    @Override
    public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
        AuthenticationPrincipalArgumentResolver resolver = new AuthenticationPrincipalArgumentResolver();
        resolver.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
        argumentResolvers.add(resolver);
    }

    /**
     * Registers the security-context, CSRF and authorization interceptors on the client
     * inbound channel, in that order.
     *
     * <p>If a {@link ChannelInterceptor} bean named {@code csrfChannelInterceptor} exists, it
     * replaces the default {@link XorCsrfChannelInterceptor}. If the observation registry is
     * not a no-op, the authorization manager is wrapped in an
     * {@link ObservationAuthorizationManager}.
     *
     * @param registration the inbound channel registration
     */
    @Override
    public void configureClientInboundChannel(ChannelRegistration registration) {
        ChannelInterceptor customCsrfInterceptor = getBeanOrNull(CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME, ChannelInterceptor.class);
        if (customCsrfInterceptor != null) {
            this.csrfChannelInterceptor = customCsrfInterceptor;
        }
        AuthorizationManager<Message<?>> manager = this.authorizationManager;
        if (!this.observationRegistry.isNoop()) {
            manager = new ObservationAuthorizationManager<>(this.observationRegistry, manager);
        }
        // Declared as the concrete type: the setters below are not part of ChannelInterceptor.
        AuthorizationChannelInterceptor authorizationChannelInterceptor = new AuthorizationChannelInterceptor(manager);
        authorizationChannelInterceptor.setAuthorizationEventPublisher(new SpringAuthorizationEventPublisher(this.context));
        authorizationChannelInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
        this.securityContextChannelInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
        registration.interceptors(this.securityContextChannelInterceptor, this.csrfChannelInterceptor, authorizationChannelInterceptor);
    }

    /**
     * Overrides the {@link SecurityContextHolderStrategy}; injected when such a bean exists.
     *
     * @param securityContextHolderStrategy the strategy to use
     * @throws IllegalArgumentException if {@code securityContextHolderStrategy} is null
     */
    @Autowired(required = false)
    void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
        Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
        this.securityContextHolderStrategy = securityContextHolderStrategy;
    }

    /**
     * Overrides the message {@link AuthorizationManager}; injected when such a bean exists.
     * Otherwise every message must be authenticated.
     *
     * @param authorizationManager the manager to use
     * @throws IllegalArgumentException if {@code authorizationManager} is null
     */
    @Autowired(required = false)
    void setAuthorizationManager(AuthorizationManager<Message<?>> authorizationManager) {
        Assert.notNull(authorizationManager, "authorizationManager cannot be null");
        this.authorizationManager = authorizationManager;
    }

    /**
     * Overrides the {@link ObservationRegistry}; injected when such a bean exists.
     *
     * @param observationRegistry the registry to use
     * @throws IllegalArgumentException if {@code observationRegistry} is null
     */
    @Autowired(required = false)
    void setObservationRegistry(ObservationRegistry observationRegistry) {
        // A null registry would otherwise fail later with an NPE at isNoop().
        Assert.notNull(observationRegistry, "observationRegistry cannot be null");
        this.observationRegistry = observationRegistry;
    }

    /**
     * Adds CSRF handshake support to the {@code stompWebSocketHandlerMapping} bean, if present.
     */
    @Override // org.springframework.beans.factory.SmartInitializingSingleton
    public void afterSingletonsInstantiated() {
        SimpleUrlHandlerMapping mapping = getBeanOrNull(SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME, SimpleUrlHandlerMapping.class);
        if (mapping == null) {
            return;
        }
        configureCsrf(mapping);
    }

    /**
     * Looks up a bean of the given type by name.
     *
     * @param name the bean name
     * @param type the required bean type
     * @return the bean instance, or {@code null} if no bean of {@code type} has that name
     */
    private <T> T getBeanOrNull(String name, Class<T> type) {
        return this.context.getBeansOfType(type).get(name);
    }

    /**
     * Prepends a CSRF handshake interceptor to every handler in the mapping.
     *
     * @param mapping the STOMP handler mapping
     * @throws IllegalStateException if a handler is neither a SockJS nor a WebSocket request handler
     */
    private void configureCsrf(SimpleUrlHandlerMapping mapping) {
        for (Object handler : mapping.getHandlerMap().values()) {
            if (handler instanceof SockJsHttpRequestHandler) {
                setHandshakeInterceptors((SockJsHttpRequestHandler) handler);
            }
            else if (handler instanceof WebSocketHttpRequestHandler) {
                setHandshakeInterceptors((WebSocketHttpRequestHandler) handler);
            }
            else {
                throw new IllegalStateException("Bean " + SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME
                        + " is expected to contain mappings to either a SockJsHttpRequestHandler or a WebSocketHttpRequestHandler but got "
                        + handler);
            }
        }
    }

    /**
     * Prepends a CSRF handshake interceptor to a SockJS handler's service.
     *
     * @throws IllegalStateException if the SockJS service is not a {@link TransportHandlingSockJsService}
     */
    private void setHandshakeInterceptors(SockJsHttpRequestHandler handler) {
        SockJsService sockJsService = handler.getSockJsService();
        Assert.state(sockJsService instanceof TransportHandlingSockJsService,
                () -> "sockJsService must be instance of TransportHandlingSockJsService got " + sockJsService);
        TransportHandlingSockJsService transportService = (TransportHandlingSockJsService) sockJsService;
        transportService.setHandshakeInterceptors(prependCsrfInterceptor(transportService.getHandshakeInterceptors()));
    }

    /** Prepends a CSRF handshake interceptor to a plain WebSocket handler. */
    private void setHandshakeInterceptors(WebSocketHttpRequestHandler handler) {
        handler.setHandshakeInterceptors(prependCsrfInterceptor(handler.getHandshakeInterceptors()));
    }

    /**
     * Returns a new list with a {@link CsrfTokenHandshakeInterceptor} first, followed by the
     * given interceptors in their original order.
     */
    private static List<HandshakeInterceptor> prependCsrfInterceptor(List<HandshakeInterceptor> existing) {
        List<HandshakeInterceptor> interceptors = new ArrayList<>(existing.size() + 1);
        interceptors.add(new CsrfTokenHandshakeInterceptor());
        interceptors.addAll(existing);
        return interceptors;
    }
}
