Can I append some information to the oauth/check_token request?

2019-01-14 14:58发布

问题:

Preface

I am working on an OAuth application for security between two servers. I have an OAuth Server and a Resource Server. The Resource Server has a single .war deployed that contains 4 APIs.

Single Responsibility

  1. The OAuth server has to validate the access token that was passed by an API (1 of the 4) from that same .war.
  2. The OAuth server has to keep a hit count for a particular accessToken for a particular API. If the hit count exceeds the configured hits the OAuth server would throw a 403: Forbidden.
  3. Every API in the .war must first validate the accessToken from the OAuth server and if it's validated, then proceed to provide the response.

What I've done:

If a .war has a single API then I can simply make the two servers communicate using a webHook, below is the code that does it.

On the Resource Server Side:

My urls for different APIs are:

  • localhost:8080/API/API1
  • localhost:8080/API/API2

The code below routes any request whose path matches /API/** through the Spring Security filters:

<!-- Stateless security filter chain for every request under /API/**; access tokens are processed by resourceServerFilter. -->
<http pattern="/API/**" create-session="never" authentication-manager-ref="authenticationManager" entry-point-ref="oauthAuthenticationEntryPoint" xmlns="http://www.springframework.org/schema/security">
        <anonymous enabled="false" />
        <!-- The access rule must target the paths this chain actually handles (/API/**). A rule for another
             path prefix can never match inside this chain, which would leave API requests without any access
             rule. No method attribute is given, so the rule applies to every HTTP method. -->
        <intercept-url pattern="/API/**" access="IS_AUTHENTICATED_FULLY" />
        <custom-filter ref="resourceServerFilter" before="PRE_AUTH_FILTER" />
        <access-denied-handler ref="oauthAccessDeniedHandler" />
</http>

I have used remote token services and defined the webHook to route the request to the OAuth server

<!-- Token services for the resource server: access tokens are checked remotely against the
     authorization server's check_token endpoint instead of a local token store. -->
<bean id="tokenServices"  class="org.springframework.security.oauth2.provider.token.RemoteTokenServices">
    <!-- URL of the check_token endpoint on the authorization server. -->
    <property name="checkTokenEndpointUrl" value="http://localhost:8181/OUTPOST/oauth/check_token"/>
    <!-- Client credentials the resource server presents when calling check_token. -->
    <property name="clientId" value="atlas"/>
    <property name="clientSecret" value="atlas"/>
</bean>

Configuration for Auth server

/**
 * OAuth2 authorization-server configuration.
 *
 * <p>Sets up endpoint security, the in-memory client registry, a Redis-backed token store and the
 * approval/token-service beans that share that store. A single {@link RedisTokenStore} instance is
 * created in the constructor and reused by every bean in this class.
 */
@EnableAuthorizationServer
@Configuration
public class AuthorizationServerConfig extends AuthorizationServerConfigurerAdapter {

    private static final String REALM = "OUTPOST_API";

    private final AuthenticationManager authenticationManager;

    /** The one token store shared by the endpoints, approval handler, approval store and token services. */
    private final TokenStore redisTokenStore;

    private RedisConnectionFactory redisConnectionFactory;

    @Autowired
    private ClientDetailsService clientService;

    /**
     * @param authenticationManager  authentication manager handed to the authorization endpoints
     * @param redisConnectionFactory connection factory backing the Redis token store
     */
    @Autowired
    public AuthorizationServerConfig(AuthenticationManager authenticationManager,
                                     RedisConnectionFactory redisConnectionFactory) {
        this.authenticationManager = authenticationManager;
        this.redisConnectionFactory = redisConnectionFactory;
        this.redisTokenStore = new RedisTokenStore(redisConnectionFactory);
    }

    /**
     * Requires an authenticated caller for both the token-key and check_token endpoints and sets
     * the authentication realm.
     */
    @Override
    public void configure(AuthorizationServerSecurityConfigurer security) throws Exception {
        security.tokenKeyAccess("isAuthenticated()")
                .checkTokenAccess("isAuthenticated()")
                .realm(REALM + "/client");
    }

    /**
     * Registers the in-memory clients: {@code cl1}, which may obtain tokens, and {@code atlas},
     * which is registered with a secret only.
     */
    @Override
    public void configure(ClientDetailsServiceConfigurer clients) throws Exception {
        clients
                .inMemory()
                .withClient("cl1")
                .secret("pwd")
                .authorizedGrantTypes("password", "client_credentials", "refresh_token")
                .authorities("ROLE_CLIENT", "ROLE_ADMIN")
                .scopes("read", "write", "trust")/*
                .resourceIds("sample-oauth")*/
                .accessTokenValiditySeconds(1000)
                .refreshTokenValiditySeconds(5000)
                .and()
                .withClient("atlas")
                .secret("atlas");
    }

    /**
     * Exposes the token store as a bean.
     *
     * <p>Returns the instance created in the constructor rather than building a second
     * {@link RedisTokenStore}, so every collaborator works against the same object.
     *
     * @param redisConnectionFactory kept for signature compatibility; the store was already built
     *                               from the constructor-injected factory
     */
    @Bean
    public TokenStore tokenStore(RedisConnectionFactory redisConnectionFactory) {
        return this.redisTokenStore;
    }

    /**
     * Exception translator that prints the stack trace of every translated exception and then
     * returns the default translation unchanged (same body, headers and status).
     */
    @Bean
    public WebResponseExceptionTranslator loggingExceptionTranslator() {
        return new DefaultWebResponseExceptionTranslator() {
            @Override
            public ResponseEntity<OAuth2Exception> translate(Exception e) throws Exception {
                // This is the line that prints the stack trace to the log. You can customise this to format the trace etc if you like
                e.printStackTrace();

                // Carry on handling the exception
                ResponseEntity<OAuth2Exception> responseEntity = super.translate(e);
                HttpHeaders headers = new HttpHeaders();
                headers.setAll(responseEntity.getHeaders().toSingleValueMap());
                OAuth2Exception excBody = responseEntity.getBody();
                return new ResponseEntity<>(excBody, headers, responseEntity.getStatusCode());
            }
        };
    }

    /**
     * Wires the token store, approval handler, authentication manager and exception translator
     * into the authorization endpoints.
     */
    @Override
    public void configure(AuthorizationServerEndpointsConfigurer endpoints) throws Exception {
        // The approval handler is obtained through its @Bean method (like the exception translator)
        // instead of being injected back into the configuration class that defines it.
        endpoints.tokenStore(redisTokenStore)
                .userApprovalHandler(userApprovalHandler())
                .authenticationManager(authenticationManager)
                .exceptionTranslator(loggingExceptionTranslator());
    }

    public void setRedisConnectionFactory(RedisConnectionFactory redisConnectionFactory) {
        this.redisConnectionFactory = redisConnectionFactory;
    }

    /** Approval handler that uses the shared token store and the injected client details service. */
    @Bean
    public TokenStoreUserApprovalHandler userApprovalHandler() {
        TokenStoreUserApprovalHandler handler = new TokenStoreUserApprovalHandler();
        handler.setTokenStore(redisTokenStore);
        handler.setRequestFactory(new DefaultOAuth2RequestFactory(clientService));
        handler.setClientDetailsService(clientService);
        return handler;
    }

    /** Approval store backed by the shared token store. */
    @Bean
    public ApprovalStore approvalStore() throws Exception {
        TokenApprovalStore store = new TokenApprovalStore();
        store.setTokenStore(redisTokenStore);
        return store;
    }

    /** Primary token services: refresh tokens enabled, backed by the shared token store. */
    @Bean
    @Primary
    public DefaultTokenServices tokenServices() {
        DefaultTokenServices tokenServices = new DefaultTokenServices();
        tokenServices.setSupportRefreshToken(true);
        tokenServices.setTokenStore(redisTokenStore);
        return tokenServices;
    }

}

    /**
     * Empty subclass of {@code OAuth2AuthenticationEntryPoint} registered as a Spring component;
     * it adds no behavior of its own.
     */
    @Component
    class MyOAuth2AuthenticationEntryPoint extends OAuth2AuthenticationEntryPoint {
    }

What I need help with:

The issue is supporting a single .war that contains multiple APIs. The Spring config is created at the package level, so all the APIs in the .war share the same clientID and clientSecret.

How would my OAuth server know which specific API is being accessed, and from which API's hitCount the hit needs to be deducted?

Possible solution? I was thinking of customizing RemoteTokenServices to add a request parameter to the webHook URL, and then using a filter on the OAuth server to read the passed tag (if I may call it that).

Is this even possible? Is there a better approach that doesn't involve all these workarounds?

回答1:

Eureka !! I finally found a way out to resolve this problem.

All you have to do is :

Configuration at Resource server

Instead of using RemoteTokenServices, write a custom remote token service which appends some extra data (an additional request parameter) to the check_token request it generates.

/**
 * Resource-server token services that validate an access token by POSTing it to the authorization
 * server's check_token endpoint, adding an {@code api} form parameter that identifies which API
 * the current request targets.
 */
public class CustomRemoteTokenService implements ResourceServerTokenServices {

    protected final Log logger = LogFactory.getLog(getClass());

    private RestOperations restTemplate;

    private String checkTokenEndpointUrl;

    private String clientId;

    private String clientSecret;

    private String tokenName = "token";

    private AccessTokenConverter tokenConverter = new DefaultAccessTokenConverter();

    /**
     * Creates the service with a default {@link RestTemplate} whose error handler lets HTTP 400
     * responses through, so that their body can be inspected for an {@code "error"} entry.
     */
    public CustomRemoteTokenService() {
        RestTemplate template = new RestTemplate();
        template.setErrorHandler(new DefaultResponseErrorHandler() {
            @Override
            // Ignore 400
            public void handleError(ClientHttpResponse response) throws IOException {
                if (response.getRawStatusCode() != 400) {
                    super.handleError(response);
                }
            }
        });
        this.restTemplate = template;
    }

    public void setRestTemplate(RestOperations restTemplate) {
        this.restTemplate = restTemplate;
    }

    public void setCheckTokenEndpointUrl(String checkTokenEndpointUrl) {
        this.checkTokenEndpointUrl = checkTokenEndpointUrl;
    }

    public void setClientId(String clientId) {
        this.clientId = clientId;
    }

    public void setClientSecret(String clientSecret) {
        this.clientSecret = clientSecret;
    }

    public void setAccessTokenConverter(AccessTokenConverter accessTokenConverter) {
        this.tokenConverter = accessTokenConverter;
    }

    public void setTokenName(String tokenName) {
        this.tokenName = tokenName;
    }

    /**
     * Validates {@code accessToken} against the check_token endpoint and converts the response
     * into an authentication.
     *
     * @param accessToken the token value received by the resource server
     * @return the authentication extracted from the check_token response
     * @throws InvalidTokenException if the endpoint returns no body or a body containing an
     *                               {@code "error"} entry
     */
    @Override
    public OAuth2Authentication loadAuthentication(String accessToken) throws AuthenticationException, InvalidTokenException {
        MultiValueMap<String, String> formData = new LinkedMultiValueMap<String, String>();
        formData.add(tokenName, accessToken);

        // Tag the check_token call with the API being accessed, when it can be determined.
        HttpServletRequest request = Context.getCurrentInstance().getRequest();
        String apiId = resolveApiId(request);
        if (apiId != null) {
            formData.add("api", apiId);
        }

        HttpHeaders headers = new HttpHeaders();
        headers.set("Authorization", getAuthorizationHeader(clientId, clientSecret));
        Map<String, Object> map = postForMap(checkTokenEndpointUrl, formData, headers);

        // An empty response body cannot confirm the token; treat it as invalid instead of failing on null.
        if (map == null) {
            logger.debug("check_token returned no response body");
            throw new InvalidTokenException(accessToken);
        }

        if (map.containsKey("error")) {
            logger.debug("check_token returned error: " + map.get("error"));
            throw new InvalidTokenException(accessToken);
        }

        Assert.state(map.containsKey("client_id"), "Client id must be present in response from auth server");
        return tokenConverter.extractAuthentication(map);
    }

    @Override
    public OAuth2AccessToken readAccessToken(String accessToken) {
        throw new UnsupportedOperationException("Not supported: read access token");
    }

    /**
     * Maps the request URI to the API identifier sent to the authorization server.
     *
     * <p>This needs to be more dynamic: every time an API is added, its entry has to be added
     * here for now. Should be changed later.
     *
     * @return {@code "1"} or {@code "2"} for the known APIs, or {@code null} when there is no
     *         current request or the URI matches no known API
     */
    private static String resolveApiId(HttpServletRequest request) {
        // Check the request before touching it; there may be no current request.
        if (request == null) {
            return null;
        }
        String uri = request.getRequestURI();
        if (uri == null) {
            return null;
        }
        if (uri.contains("API1")) {
            return "1";
        }
        if (uri.contains("API2")) {
            return "2";
        }
        return null;
    }

    /** Builds the HTTP Basic {@code Authorization} header value for the given client credentials. */
    private String getAuthorizationHeader(String clientId, String clientSecret) {
        String creds = String.format("%s:%s", clientId, clientSecret);
        try {
            return "Basic " + new String(Base64.encode(creds.getBytes("UTF-8")));
        }
        catch (UnsupportedEncodingException e) {
            // Keep the original exception as the cause so the failure stays diagnosable.
            throw new IllegalStateException("Could not convert String", e);
        }
    }

    /**
     * POSTs {@code formData} to {@code path} (form-urlencoded unless a content type is already set)
     * and returns the response body as a map; the body may be {@code null}.
     */
    private Map<String, Object> postForMap(String path, MultiValueMap<String, String> formData, HttpHeaders headers) {
        if (headers.getContentType() == null) {
            headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
        }
        @SuppressWarnings("rawtypes")
        Map map = restTemplate.exchange(path, HttpMethod.POST,
                new HttpEntity<MultiValueMap<String, String>>(formData, headers), Map.class).getBody();
        @SuppressWarnings("unchecked")
        Map<String, Object> result = map;
        return result;
    }

}

By implementing ResourceServerTokenServices you can modify the request that is sent by the resource server to the auth server for authentication and authorization.

configuration at Auth Server

Override the Spring Security controller. What I mean by overriding is: write a custom controller so that requests to oauth/check_token are handled by your controller and not by the one Spring defines.

/**
 * Custom handler for {@code /oauth/check_token}: validates the token, converts it to the response
 * map, and additionally consults {@code KeyHitManager} for the requested API and the token's
 * client before answering.
 */
@RestController
public class CustomCheckTokenEndpoint {

    protected final Log logger = LogFactory.getLog(getClass());

    private ResourceServerTokenServices resourceServerTokenServices;

    private AccessTokenConverter accessTokenConverter = new DefaultAccessTokenConverter();

    private WebResponseExceptionTranslator exceptionTranslator = new DefaultWebResponseExceptionTranslator();

    @Autowired
    KeyHitManager keyHitManager;

    public CustomCheckTokenEndpoint(ResourceServerTokenServices resourceServerTokenServices) {
        this.resourceServerTokenServices = resourceServerTokenServices;
    }

    /**
     * @param exceptionTranslator
     *            translator used to render invalid-token errors
     */
    public void setExceptionTranslator(WebResponseExceptionTranslator exceptionTranslator) {
        this.exceptionTranslator = exceptionTranslator;
    }

    /**
     * @param accessTokenConverter
     *            converter used to turn a token and its authentication into the response map
     */
    public void setAccessTokenConverter(AccessTokenConverter accessTokenConverter) {
        this.accessTokenConverter = accessTokenConverter;
    }

    /**
     * Checks the given token for the given API.
     *
     * @param value the access token value ({@code token} request parameter)
     * @param api   the API identifier ({@code api} request parameter)
     * @return the converted token/authentication details
     * @throws InvalidTokenException if the token is unknown or expired, or if the key-hit manager
     *                               reports no hit available for this API and client
     */
    @RequestMapping(value = "/oauth/check_token")
    @ResponseBody
    public Map<String, ?> customCheckToken(@RequestParam("token") String value, @RequestParam("api") int api) {
        OAuth2AccessToken accessToken = readUsableToken(value);

        OAuth2Authentication auth = resourceServerTokenServices.loadAuthentication(accessToken.getValue());
        Map<String, ?> details = accessTokenConverter.convertAccessToken(accessToken, auth);

        String owningClient = (String) details.get("client_id");
        boolean hitAllowed = keyHitManager.isHitAvailble(api, owningClient);
        if (hitAllowed) {
            return details;
        }
        throw new InvalidTokenException(
                "Services for this key has been suspended due to daily/hourly transactions limit");
    }

    /** Reads the token for {@code value}, rejecting tokens that are unknown or already expired. */
    private OAuth2AccessToken readUsableToken(String value) {
        OAuth2AccessToken found = resourceServerTokenServices.readAccessToken(value);
        if (found == null) {
            throw new InvalidTokenException("Token was not recognised");
        }
        if (found.isExpired()) {
            throw new InvalidTokenException("Token has expired");
        }
        return found;
    }

    /**
     * Translates an invalid-token failure into a response whose error code is 400.
     */
    @ExceptionHandler(InvalidTokenException.class)
    public ResponseEntity<OAuth2Exception> handleException(Exception e) throws Exception {
        logger.info("Handling error: " + e.getClass().getSimpleName() + ", " + e.getMessage());
        // Not an OAuth resource: the caller has already authenticated with basic auth,
        // so report the invalid token as a 400 rather than as "unauthorized".
        @SuppressWarnings("serial")
        InvalidTokenException badRequest = new InvalidTokenException(e.getMessage()) {
            @Override
            public int getHttpErrorCode() {
                return 400;
            }
        };
        return exceptionTranslator.translate(badRequest);
    }
}