Spring Boot 2.7 OAuth2 XML响应处理完整实现

1. 依赖配置 (pom.xml)

<dependencies>
    <!-- Servlet stack (embedded Tomcat + Spring MVC): required to serve @RestController
         endpoints and the OAuth2 login redirect/callback endpoints. -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-oauth2-client</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-jdbc</artifactId>
    </dependency>
    <dependency>
        <groupId>org.mybatis.spring.boot</groupId>
        <artifactId>mybatis-spring-boot-starter</artifactId>
        <version>2.3.1</version>
    </dependency>
    <dependency>
        <groupId>mysql</groupId>
        <artifactId>mysql-connector-java</artifactId>
        <scope>runtime</scope>
    </dependency>
    <!-- XML processing: Jackson XmlMapper used to parse the XML token response -->
    <dependency>
        <groupId>com.fasterxml.jackson.dataformat</groupId>
        <artifactId>jackson-dataformat-xml</artifactId>
    </dependency>
</dependencies>

2. 数据库表结构

-- One row per OAuth2 authorized client, keyed by (client_registration_id, principal_name);
-- the unique key uk_client_principal enforces that pairing.
CREATE TABLE oauth2_tokens (
    id BIGINT AUTO_INCREMENT PRIMARY KEY,
    -- Spring Security client registration id (e.g. 'third-party' in application.yml).
    client_registration_id VARCHAR(100) NOT NULL,
    -- Name of the authenticated principal that owns the token.
    principal_name VARCHAR(255) NOT NULL,
    access_token TEXT NOT NULL,
    -- NULL when no refresh token was stored.
    refresh_token TEXT,
    token_type VARCHAR(50) DEFAULT 'Bearer',
    -- Access-token expiry; NULL when unknown.
    -- NOTE(review): MySQL TIMESTAMP cannot hold values after 2038-01-19 and is converted through
    -- the session time zone; consider DATETIME if either matters for your deployment.
    expires_at TIMESTAMP NULL,
    -- Granted scopes; the delimiter is chosen by the application code that writes this column.
    scopes TEXT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    -- Set to the current time by MySQL whenever the row changes.
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    UNIQUE KEY uk_client_principal (client_registration_id, principal_name)
);

3. 实体类

package com.example.oauth2.entity;

import java.time.LocalDateTime;

/**
 * One row of the {@code oauth2_tokens} table.
 *
 * <p>A plain mutable bean: a no-argument constructor plus a getter/setter pair for every
 * column. {@code id}, {@code createdAt} and {@code updatedAt} are not set by the convenience
 * constructor because the database fills them in.
 */
public class OAuth2Token {

    private Long id;
    private String clientRegistrationId;
    private String principalName;
    private String accessToken;
    private String refreshToken;
    private String tokenType;
    private LocalDateTime expiresAt;
    private String scopes;
    private LocalDateTime createdAt;
    private LocalDateTime updatedAt;

    /** Creates an empty instance whose fields are populated through the setters. */
    public OAuth2Token() {
    }

    /**
     * Creates an instance carrying every token-related column.
     *
     * @param clientRegistrationId client registration id
     * @param principalName        owning principal's name
     * @param accessToken          access token value
     * @param refreshToken         refresh token value, may be {@code null}
     * @param tokenType            token type, e.g. {@code "Bearer"}
     * @param expiresAt            access token expiry, may be {@code null}
     * @param scopes               serialized scopes, may be {@code null}
     */
    public OAuth2Token(String clientRegistrationId, String principalName,
                       String accessToken, String refreshToken, String tokenType,
                       LocalDateTime expiresAt, String scopes) {
        this.scopes = scopes;
        this.expiresAt = expiresAt;
        this.tokenType = tokenType;
        this.refreshToken = refreshToken;
        this.accessToken = accessToken;
        this.principalName = principalName;
        this.clientRegistrationId = clientRegistrationId;
    }

    // ---- accessors -------------------------------------------------------

    public Long getId() {
        return id;
    }

    public String getClientRegistrationId() {
        return clientRegistrationId;
    }

    public String getPrincipalName() {
        return principalName;
    }

    public String getAccessToken() {
        return accessToken;
    }

    public String getRefreshToken() {
        return refreshToken;
    }

    public String getTokenType() {
        return tokenType;
    }

    public LocalDateTime getExpiresAt() {
        return expiresAt;
    }

    public String getScopes() {
        return scopes;
    }

    public LocalDateTime getCreatedAt() {
        return createdAt;
    }

    public LocalDateTime getUpdatedAt() {
        return updatedAt;
    }

    // ---- mutators --------------------------------------------------------

    public void setId(Long id) {
        this.id = id;
    }

    public void setClientRegistrationId(String clientRegistrationId) {
        this.clientRegistrationId = clientRegistrationId;
    }

    public void setPrincipalName(String principalName) {
        this.principalName = principalName;
    }

    public void setAccessToken(String accessToken) {
        this.accessToken = accessToken;
    }

    public void setRefreshToken(String refreshToken) {
        this.refreshToken = refreshToken;
    }

    public void setTokenType(String tokenType) {
        this.tokenType = tokenType;
    }

    public void setExpiresAt(LocalDateTime expiresAt) {
        this.expiresAt = expiresAt;
    }

    public void setScopes(String scopes) {
        this.scopes = scopes;
    }

    public void setCreatedAt(LocalDateTime createdAt) {
        this.createdAt = createdAt;
    }

    public void setUpdatedAt(LocalDateTime updatedAt) {
        this.updatedAt = updatedAt;
    }
}

4. MyBatis Mapper

package com.example.oauth2.mapper;

import com.example.oauth2.entity.OAuth2Token;
import org.apache.ibatis.annotations.*;

import java.util.Optional;

/**
 * MyBatis mapper for the {@code oauth2_tokens} table.
 *
 * <p>Rows are addressed by the (client_registration_id, principal_name) pair, which the table
 * declares as a unique key, so each lookup/update/delete touches at most one row.
 */
@Mapper
public interface OAuth2TokenMapper {

    /**
     * Inserts a new token row. The database-generated {@code id} is written back into
     * {@code token}; created_at/updated_at are left to their column defaults.
     *
     * @param token row values to insert
     * @return number of inserted rows
     */
    @Insert("INSERT INTO oauth2_tokens (client_registration_id, principal_name, " +
            "access_token, refresh_token, token_type, expires_at, scopes) " +
            "VALUES (#{clientRegistrationId}, #{principalName}, #{accessToken}, " +
            "#{refreshToken}, #{tokenType}, #{expiresAt}, #{scopes})")
    @Options(useGeneratedKeys = true, keyProperty = "id")
    int insert(OAuth2Token token);

    /**
     * Overwrites the token columns of the row matching the token's registration id and principal.
     *
     * @param token new values plus the (clientRegistrationId, principalName) key
     * @return the row count reported by the JDBC driver; 0 when no row matched.
     *         NOTE(review): with MySQL and {@code useAffectedRows=true} a matched but unchanged
     *         row may also report 0 - confirm the connection settings before relying on this.
     */
    @Update("UPDATE oauth2_tokens SET access_token = #{accessToken}, " +
            "refresh_token = #{refreshToken}, token_type = #{tokenType}, " +
            "expires_at = #{expiresAt}, scopes = #{scopes} " +
            "WHERE client_registration_id = #{clientRegistrationId} " +
            "AND principal_name = #{principalName}")
    int update(OAuth2Token token);

    /**
     * Finds the row for the given registration id and principal.
     *
     * @return the row, or an empty {@code Optional} when none exists
     */
    @Select("SELECT * FROM oauth2_tokens WHERE client_registration_id = #{clientRegistrationId} " +
            "AND principal_name = #{principalName}")
    Optional<OAuth2Token> findByClientRegistrationIdAndPrincipalName(
            @Param("clientRegistrationId") String clientRegistrationId,
            @Param("principalName") String principalName);

    /**
     * Deletes the row for the given registration id and principal.
     *
     * @return number of deleted rows (0 or 1, given the unique key)
     */
    @Delete("DELETE FROM oauth2_tokens WHERE client_registration_id = #{clientRegistrationId} " +
            "AND principal_name = #{principalName}")
    int deleteByClientRegistrationIdAndPrincipalName(
            @Param("clientRegistrationId") String clientRegistrationId,
            @Param("principalName") String principalName);
}

5. XML响应解析器

package com.example.oauth2.converter;

import com.fasterxml.jackson.dataformat.xml.annotation.JacksonXmlProperty;
import com.fasterxml.jackson.dataformat.xml.annotation.JacksonXmlRootElement;

/**
 * Jackson XML binding for a token endpoint response shaped like
 * {@code <oauth><access_token>..</access_token><token_type>..</token_type>...</oauth>}.
 *
 * <p>Each field maps one child element named after the standard OAuth2 token response
 * parameter. Any element may be absent, in which case the field stays {@code null}.
 */
@JacksonXmlRootElement(localName = "oauth")
public class OAuth2AccessTokenXmlResponse {

    /** {@code <access_token>} element. */
    @JacksonXmlProperty(localName = "access_token")
    private String accessToken;

    /** {@code <token_type>} element, e.g. "bearer". */
    @JacksonXmlProperty(localName = "token_type")
    private String tokenType;

    /** {@code <expires_in>} element: lifetime in seconds. */
    @JacksonXmlProperty(localName = "expires_in")
    private Long expiresIn;

    /** {@code <refresh_token>} element. */
    @JacksonXmlProperty(localName = "refresh_token")
    private String refreshToken;

    /** {@code <scope>} element: raw scope string as sent by the provider. */
    @JacksonXmlProperty(localName = "scope")
    private String scope;

    /** No-argument constructor used by Jackson when deserializing. */
    public OAuth2AccessTokenXmlResponse() {
    }

    // ---- accessors -------------------------------------------------------

    public String getAccessToken() {
        return accessToken;
    }

    public String getTokenType() {
        return tokenType;
    }

    public Long getExpiresIn() {
        return expiresIn;
    }

    public String getRefreshToken() {
        return refreshToken;
    }

    public String getScope() {
        return scope;
    }

    // ---- mutators --------------------------------------------------------

    public void setAccessToken(String accessToken) {
        this.accessToken = accessToken;
    }

    public void setTokenType(String tokenType) {
        this.tokenType = tokenType;
    }

    public void setExpiresIn(Long expiresIn) {
        this.expiresIn = expiresIn;
    }

    public void setRefreshToken(String refreshToken) {
        this.refreshToken = refreshToken;
    }

    public void setScope(String scope) {
        this.scope = scope;
    }
}

6. 自定义 OAuth2 Access Token 响应客户端

package com.example.oauth2.converter;

import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.dataformat.xml.XmlMapper;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequestEntityConverter;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestTemplate;

import javax.xml.stream.XMLInputFactory;
import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;

/**
 * {@link OAuth2AccessTokenResponseClient} for the authorization-code grant whose token endpoint
 * answers in XML (see {@link OAuth2AccessTokenXmlResponse}) instead of JSON.
 *
 * <p>Every failure (transport error, empty/unparseable body, missing access token, unsupported
 * token type) is reported as an {@link OAuth2AuthorizationException} with error code
 * {@code invalid_token_response}. Spring Security's login flow turns that exception into a
 * login failure instead of an HTTP 500.
 */
public class XmlOAuth2AccessTokenResponseClient 
        implements OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> {

    /** OAuth2 error code used for any unusable token endpoint response. */
    private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response";

    /** Media types advertised in the Accept header, since the body is parsed as XML. */
    private static final List<MediaType> XML_MEDIA_TYPES =
            Arrays.asList(MediaType.APPLICATION_XML, MediaType.TEXT_XML);

    /** The {@code scope} parameter is a space-delimited list. */
    private static final Pattern SCOPE_DELIMITER = Pattern.compile("\\s+");

    private final RestTemplate restTemplate = new RestTemplate();
    private final XmlMapper xmlMapper = createXmlMapper();
    private final Converter<OAuth2AuthorizationCodeGrantRequest, RequestEntity<?>> requestEntityConverter;

    public XmlOAuth2AccessTokenResponseClient() {
        this.requestEntityConverter = new OAuth2AuthorizationCodeGrantRequestEntityConverter();
    }

    /**
     * Exchanges the authorization code for tokens and parses the XML response.
     *
     * @param authorizationGrantRequest the authorization-code grant request
     * @return the parsed token response; when the provider omits {@code scope}, the scopes
     *         requested by the client registration are used (RFC 6749 section 5.1)
     * @throws OAuth2AuthorizationException if the endpoint cannot be reached or its response
     *                                      is unusable
     */
    @Override
    public OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizationCodeGrantRequest authorizationGrantRequest) {
        RequestEntity<?> request = withXmlAccept(this.requestEntityConverter.convert(authorizationGrantRequest));
        byte[] body = exchange(request);
        OAuth2AccessTokenXmlResponse xmlResponse = parse(body);
        return convertToOAuth2AccessTokenResponse(
                xmlResponse, authorizationGrantRequest.getClientRegistration().getScopes());
    }

    /** Builds the XML mapper: DTD/external-entity processing off, unknown elements ignored. */
    private static XmlMapper createXmlMapper() {
        XMLInputFactory inputFactory = XMLInputFactory.newFactory();
        // The document comes from a remote server: refuse DTDs and external entities (XXE).
        inputFactory.setProperty(XMLInputFactory.SUPPORT_DTD, Boolean.FALSE);
        inputFactory.setProperty(XMLInputFactory.IS_SUPPORTING_EXTERNAL_ENTITIES, Boolean.FALSE);
        XmlMapper mapper = new XmlMapper(inputFactory);
        // Providers often add vendor-specific elements; they must not break parsing.
        mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        return mapper;
    }

    /**
     * Copies the request, replacing the JSON Accept header set by the default converter with
     * XML media types so content-negotiating providers answer in the format we parse.
     */
    private static RequestEntity<?> withXmlAccept(RequestEntity<?> request) {
        if (request == null) {
            throw new IllegalStateException("The token request entity converter returned null");
        }
        HttpHeaders headers = new HttpHeaders();
        headers.putAll(request.getHeaders());
        headers.setAccept(XML_MEDIA_TYPES);
        return new RequestEntity<Object>(request.getBody(), headers, request.getMethod(), request.getUrl());
    }

    /**
     * Sends the token request and returns the raw body. The body is read as bytes so the XML
     * parser, not an HTTP default charset, determines its encoding.
     */
    private byte[] exchange(RequestEntity<?> request) {
        ResponseEntity<byte[]> response;
        try {
            response = this.restTemplate.exchange(request, byte[].class);
        } catch (RestClientException ex) {
            throw invalidTokenResponse(
                    "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: "
                            + ex.getMessage(), ex);
        }
        byte[] body = response.getBody();
        if (body == null || body.length == 0) {
            throw invalidTokenResponse("The token endpoint returned an empty response body", null);
        }
        return body;
    }

    /** Deserializes the XML body into {@link OAuth2AccessTokenXmlResponse}. */
    private OAuth2AccessTokenXmlResponse parse(byte[] body) {
        OAuth2AccessTokenXmlResponse xmlResponse;
        try {
            xmlResponse = this.xmlMapper.readValue(body, OAuth2AccessTokenXmlResponse.class);
        } catch (IOException ex) {
            throw invalidTokenResponse("Failed to parse XML token response: " + ex.getMessage(), ex);
        }
        if (xmlResponse == null) {
            throw invalidTokenResponse("Failed to parse XML token response: empty document", null);
        }
        return xmlResponse;
    }

    /**
     * Validates the parsed XML and maps it onto Spring's {@link OAuth2AccessTokenResponse}.
     *
     * @param xmlResponse     parsed response
     * @param requestedScopes scopes of the client registration, used when the response has none
     */
    private OAuth2AccessTokenResponse convertToOAuth2AccessTokenResponse(
            OAuth2AccessTokenXmlResponse xmlResponse, Set<String> requestedScopes) {
        String tokenValue = xmlResponse.getAccessToken();
        if (tokenValue == null || tokenValue.trim().isEmpty()) {
            throw invalidTokenResponse("The XML token response does not contain an access_token", null);
        }

        // Only Bearer tokens are supported; a missing/blank token_type is tolerated as Bearer,
        // but an explicit other type must not be silently relabelled.
        String tokenType = xmlResponse.getTokenType();
        if (tokenType != null && !tokenType.trim().isEmpty()
                && !OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(tokenType.trim())) {
            throw invalidTokenResponse("Unsupported token_type in XML token response: " + tokenType, null);
        }

        Set<String> scopes = parseScopes(xmlResponse.getScope());
        if (scopes.isEmpty() && requestedScopes != null) {
            scopes = requestedScopes;
        }

        OAuth2AccessTokenResponse.Builder builder = OAuth2AccessTokenResponse.withToken(tokenValue)
                .tokenType(OAuth2AccessToken.TokenType.BEARER)
                .scopes(scopes)
                .refreshToken(xmlResponse.getRefreshToken());
        // expiresIn(long) takes a primitive: passing a null Long would throw an NPE.
        if (xmlResponse.getExpiresIn() != null) {
            builder.expiresIn(xmlResponse.getExpiresIn());
        }
        return builder.build();
    }

    /** Splits a space-delimited scope string, dropping empty entries and duplicates. */
    private static Set<String> parseScopes(String scope) {
        Set<String> scopes = new LinkedHashSet<>();
        if (scope != null) {
            for (String value : SCOPE_DELIMITER.split(scope.trim())) {
                if (!value.isEmpty()) {
                    scopes.add(value);
                }
            }
        }
        return scopes;
    }

    /** Creates the exception type Spring Security's login flow expects from token clients. */
    private static OAuth2AuthorizationException invalidTokenResponse(String description, Throwable cause) {
        OAuth2Error error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, description, null);
        return (cause != null)
                ? new OAuth2AuthorizationException(error, cause)
                : new OAuth2AuthorizationException(error);
    }
}

7. 自定义 OAuth2 Token 持久化服务

package com.example.oauth2.service;

import com.example.oauth2.entity.OAuth2Token;
import com.example.oauth2.mapper.OAuth2TokenMapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.stereotype.Service;

import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.Optional;
import java.util.stream.Collectors;

@Service
public class DatabaseOAuth2AuthorizedClientService implements OAuth2AuthorizedClientService {

    @Autowired
    private OAuth2TokenMapper tokenMapper;

    @Override
    public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(
            String clientRegistrationId, String principalName) {
        
        Optional<OAuth2Token> tokenOpt = tokenMapper.findByClientRegistrationIdAndPrincipalName(
                clientRegistrationId, principalName);
        
        if (tokenOpt.isPresent()) {
            OAuth2Token token = tokenOpt.get();
            
            // 这里需要重新构建 ClientRegistration,实际项目中应该从配置中获取
            ClientRegistration clientRegistration = getClientRegistration(clientRegistrationId);
            
            OAuth2AccessToken accessToken = new OAuth2AccessToken(
                    OAuth2AccessToken.TokenType.valueOf(token.getTokenType().toUpperCase()),
                    token.getAccessToken(),
                    token.getCreatedAt().atZone(ZoneId.systemDefault()).toInstant(),
                    token.getExpiresAt() != null ? 
                        token.getExpiresAt().atZone(ZoneId.systemDefault()).toInstant() : null,
                    token.getScopes() != null ? 
                        Set.of(token.getScopes().split(",")) : null
            );

            OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
                    clientRegistration, principalName, accessToken);
            
            return (T) authorizedClient;
        }
        
        return null;
    }

    @Override
    public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, 
                                   Authentication principal) {
        String clientRegistrationId = authorizedClient.getClientRegistration().getRegistrationId();
        String principalName = principal.getName();
        OAuth2AccessToken accessToken = authorizedClient.getAccessToken();

        LocalDateTime expiresAt = accessToken.getExpiresAt() != null ?
                LocalDateTime.ofInstant(accessToken.getExpiresAt(), ZoneId.systemDefault()) : null;

        String scopes = accessToken.getScopes() != null ?
                accessToken.getScopes().stream().collect(Collectors.joining(",")) : null;

        OAuth2Token token = new OAuth2Token(
                clientRegistrationId,
                principalName,
                accessToken.getTokenValue(),
                authorizedClient.getRefreshToken() != null ? 
                    authorizedClient.getRefreshToken().getTokenValue() : null,
                accessToken.getTokenType().getValue(),
                expiresAt,
                scopes
        );

        // 先尝试更新,如果更新失败则插入
        int updated = tokenMapper.update(token);
        if (updated == 0) {
            tokenMapper.insert(token);
        }
    }

    @Override
    public void removeAuthorizedClient(String clientRegistrationId, String principalName) {
        tokenMapper.deleteByClientRegistrationIdAndPrincipalName(clientRegistrationId, principalName);
    }

    // 这个方法需要根据实际情况实现,从Spring Security配置中获取ClientRegistration
    private ClientRegistration getClientRegistration(String registrationId) {
        // 实际实现中应该注入 ClientRegistrationRepository
        throw new UnsupportedOperationException("需要实现 ClientRegistration 获取逻辑");
    }
}

8. OAuth2 配置类

package com.example.oauth2.config;

import com.example.oauth2.converter.XmlOAuth2AccessTokenResponseClient;
import com.example.oauth2.service.DatabaseOAuth2AuthorizedClientService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.web.SecurityFilterChain;

/**
 * Spring Security configuration: OAuth2 login that uses the XML-aware token client and stores
 * authorized clients through the database-backed service.
 */
@Configuration
@EnableWebSecurity
public class OAuth2Config {

    @Autowired
    private DatabaseOAuth2AuthorizedClientService authorizedClientService;

    /** Token client that parses the provider's XML token endpoint response. */
    @Bean
    public OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> 
            xmlAccessTokenResponseClient() {
        return new XmlOAuth2AccessTokenResponseClient();
    }

    /** Persists authorized clients of authenticated users via the database-backed service. */
    @Bean
    public OAuth2AuthorizedClientRepository authorizedClientRepository() {
        return new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService);
    }

    /**
     * Public landing/login/error pages; everything else requires an OAuth2 login.
     */
    @Bean
    public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
        http
            .authorizeHttpRequests(authz -> authz
                // Spring Security 5.7 (used by Boot 2.7) has no requestMatchers(String...);
                // antMatchers is the String-pattern variant. "/error" stays public so failures
                // during login can be rendered instead of triggering another login redirect.
                .antMatchers("/", "/login**", "/error").permitAll()
                .anyRequest().authenticated()
            )
            .oauth2Login(oauth2 -> oauth2
                .tokenEndpoint(token -> token
                    .accessTokenResponseClient(xmlAccessTokenResponseClient())
                )
                .authorizedClientRepository(authorizedClientRepository())
            );
        
        return http.build();
    }
}

9. 应用配置 (application.yml)

spring:
  security:
    oauth2:
      client:
        registration:
          third-party:
            client-id: your-client-id
            client-secret: your-client-secret
            authorization-grant-type: authorization_code
            # Template variables resolve to http://localhost:8080/login/oauth2/code/third-party
            # locally and follow the real host when deployed elsewhere.
            redirect-uri: "{baseUrl}/login/oauth2/code/{registrationId}"
            scope: read,write
        provider:
          third-party:
            authorization-uri: https://third-party.example.com/oauth2/authorize
            token-uri: https://third-party.example.com/oauth2/token
            user-info-uri: https://third-party.example.com/oauth2/userinfo
            # Required by oauth2Login's default user service: the user-info attribute that
            # holds the user's unique name/id.
            # TODO: replace "id" with the attribute your provider actually returns.
            user-name-attribute: id

  datasource:
    url: jdbc:mysql://localhost:3306/oauth2_db
    username: your-username
    password: your-password
    driver-class-name: com.mysql.cj.jdbc.Driver

mybatis:
  configuration:
    map-underscore-to-camel-case: true
    # StdOutImpl prints every SQL statement with its bound parameters, which here include
    # access and refresh tokens. Enable only for local debugging.
    # log-impl: org.apache.ibatis.logging.stdout.StdOutImpl

10. 测试控制器

package com.example.oauth2.controller;

import org.springframework.security.core.annotation.AuthenticationPrincipal;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
public class OAuth2Controller {

    @GetMapping("/")
    public String home() {
        return "欢迎!请点击 <a href='/oauth2/authorization/third-party'>登录</a>";
    }

    @GetMapping("/user")
    public String user(@RegisteredOAuth2AuthorizedClient("third-party") OAuth2AuthorizedClient authorizedClient,
                      @AuthenticationPrincipal OAuth2User oauth2User) {
        
        return String.format("用户: %s<br/>Access Token: %s", 
                           oauth2User.getName(),
                           authorizedClient.getAccessToken().getTokenValue());
    }
}

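使用说明

  1. 数据库配置:执行上面的建表语句,创建 oauth2_tokens 表。
  2. 依赖配置:在 pom.xml 中添加所需的 Maven 依赖。
  3. OAuth2 配置:在 application.yml 中配置第三方 OAuth2 提供商的客户端注册信息和各端点地址。
  4. XML 解析:OAuth2AccessTokenXmlResponse 类负责映射 XML 格式令牌响应中的各个元素。
  5. 持久化:DatabaseOAuth2AuthorizedClientService 负责将令牌存入数据库。从数据库读取时,需要根据 registrationId 找回对应的 ClientRegistration(可注入 ClientRegistrationRepository 获取)。
  6. 自定义客户端:XmlOAuth2AccessTokenResponseClient 负责调用令牌端点,并解析其返回的 XML 响应。

这个实现演示了如何在 OAuth2 授权码流程中解析 XML 格式的令牌响应,并将令牌持久化到数据库。注意:本示例只处理令牌端点;如果第三方的 userinfo 端点同样返回 XML,还需要自定义 OAuth2UserService。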