Skip to content

Commit

Permalink
add support for redirect limits (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
artemredkin authored Oct 23, 2019
1 parent c1c3da3 commit 51dc885
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 11 deletions.
57 changes: 50 additions & 7 deletions Sources/AsyncHTTPClient/HTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -227,15 +227,20 @@ public class HTTPClient {
channelEL: EventLoop? = nil,
deadline: NIODeadline? = nil) -> Task<Delegate.Response> {
let redirectHandler: RedirectHandler<Delegate.Response>?
if self.configuration.followRedirects {
switch self.configuration.redirectConfiguration.configuration {
case .follow(let max, let allowCycles):
var request = request
if request.redirectState == nil {
request.redirectState = .init(count: max, visited: allowCycles ? nil : Set())
}
redirectHandler = RedirectHandler<Delegate.Response>(request: request) { newRequest in
self.execute(request: newRequest,
delegate: delegate,
eventLoop: delegateEL,
channelEL: channelEL,
deadline: deadline)
}
} else {
case .disallow:
redirectHandler = nil
}

Expand Down Expand Up @@ -325,7 +330,7 @@ public class HTTPClient {
/// - `305: Use Proxy`
/// - `307: Temporary Redirect`
/// - `308: Permanent Redirect`
public var followRedirects: Bool
public var redirectConfiguration: RedirectConfiguration
/// Default client timeout, defaults to no timeouts.
public var timeout: Timeout
/// Upstream proxy, defaults to no proxy.
Expand All @@ -336,27 +341,27 @@ public class HTTPClient {
public var ignoreUncleanSSLShutdown: Bool

public init(tlsConfiguration: TLSConfiguration? = nil,
followRedirects: Bool = false,
redirectConfiguration: RedirectConfiguration? = nil,
timeout: Timeout = Timeout(),
proxy: Proxy? = nil,
ignoreUncleanSSLShutdown: Bool = false,
decompression: Decompression = .disabled) {
self.tlsConfiguration = tlsConfiguration
self.followRedirects = followRedirects
self.redirectConfiguration = redirectConfiguration ?? RedirectConfiguration()
self.timeout = timeout
self.proxy = proxy
self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown
self.decompression = decompression
}

public init(certificateVerification: CertificateVerification,
followRedirects: Bool = false,
redirectConfiguration: RedirectConfiguration? = nil,
timeout: Timeout = Timeout(),
proxy: Proxy? = nil,
ignoreUncleanSSLShutdown: Bool = false,
decompression: Decompression = .disabled) {
self.tlsConfiguration = TLSConfiguration.forClient(certificateVerification: certificateVerification)
self.followRedirects = followRedirects
self.redirectConfiguration = redirectConfiguration ?? RedirectConfiguration()
self.timeout = timeout
self.proxy = proxy
self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown
Expand Down Expand Up @@ -439,6 +444,38 @@ extension HTTPClient.Configuration {
self.read = read
}
}

/// Specifies redirect processing settings.
public struct RedirectConfiguration {
enum Configuration {
/// Redirects are not followed.
case disallow
/// Redirects are followed with a specified limit.
case follow(max: Int, allowCycles: Bool)
}

var configuration: Configuration

init() {
self.configuration = .follow(max: 5, allowCycles: false)
}

init(configuration: Configuration) {
self.configuration = configuration
}

/// Redirects are not followed.
public static let disallow = RedirectConfiguration(configuration: .disallow)

/// Redirects are followed with a specified limit.
///
/// - parameters:
/// - max: The maximum number of allowed redirects.
/// - allowCycles: Whether cycles are allowed.
///
/// - warning: Cycle detection will keep all visited URLs in memory which means a malicious server could use this as a denial-of-service vector.
public static func follow(max: Int, allowCycles: Bool) -> RedirectConfiguration { return .init(configuration: .follow(max: max, allowCycles: allowCycles)) }
}
}

private extension ChannelPipeline {
Expand Down Expand Up @@ -488,6 +525,8 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
case invalidProxyResponse
case contentLengthMissing
case proxyAuthenticationRequired
case redirectLimitReached
case redirectCycleDetected
}

private var code: Code
Expand Down Expand Up @@ -526,4 +565,8 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
public static let contentLengthMissing = HTTPClientError(code: .contentLengthMissing)
/// Proxy Authentication Required.
public static let proxyAuthenticationRequired = HTTPClientError(code: .proxyAuthenticationRequired)
/// Redirect Limit reached.
public static let redirectLimitReached = HTTPClientError(code: .redirectLimitReached)
/// Redirect Cycle detected.
public static let redirectCycleDetected = HTTPClientError(code: .redirectCycleDetected)
}
32 changes: 31 additions & 1 deletion Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ extension HTTPClient {
/// Request body, defaults to no body.
public var body: Body?

struct RedirectState {
var count: Int
var visited: Set<URL>?
}

var redirectState: RedirectState?

/// Create HTTP request.
///
/// - parameters:
Expand Down Expand Up @@ -152,6 +159,8 @@ extension HTTPClient {
self.host = host
self.headers = headers
self.body = body

self.redirectState = nil
}

/// Whether request will be executed using secure socket.
Expand Down Expand Up @@ -813,6 +822,26 @@ internal struct RedirectHandler<ResponseType> {
}

func redirect(status: HTTPResponseStatus, to redirectURL: URL, promise: EventLoopPromise<ResponseType>) {
var nextState: HTTPClient.Request.RedirectState?
if var state = request.redirectState {
guard state.count > 0 else {
return promise.fail(HTTPClientError.redirectLimitReached)
}

state.count -= 1

if var visited = state.visited {
guard !visited.contains(redirectURL) else {
return promise.fail(HTTPClientError.redirectCycleDetected)
}

visited.insert(redirectURL)
state.visited = visited
}

nextState = state
}

let originalRequest = self.request

var convertToGet = false
Expand Down Expand Up @@ -841,7 +870,8 @@ internal struct RedirectHandler<ResponseType> {
}

do {
let newRequest = try HTTPClient.Request(url: redirectURL, method: method, headers: headers, body: body)
var newRequest = try HTTPClient.Request(url: redirectURL, method: method, headers: headers, body: body)
newRequest.redirectState = nextState
return self.execute(newRequest).futureResult.cascade(to: promise)
} catch {
return promise.fail(error)
Expand Down
10 changes: 10 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,16 @@ internal final class HttpBinHandler: ChannelInboundHandler {
headers.add(name: "Location", value: "http://127.0.0.1:\(port)/echohostheader")
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
return
case "/redirect/infinite1":
var headers = HTTPHeaders()
headers.add(name: "Location", value: "/redirect/infinite2")
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
return
case "/redirect/infinite2":
var headers = HTTPHeaders()
headers.add(name: "Location", value: "/redirect/infinite1")
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
return
// Since this String is taken from URL.path, the percent encoding has been removed
case "/percent encoded":
if req.method != .GET {
Expand Down
2 changes: 2 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ extension HTTPClientTests {
("testEventLoopArgument", testEventLoopArgument),
("testDecompression", testDecompression),
("testDecompressionLimit", testDecompressionLimit),
("testLoopDetectionRedirectLimit", testLoopDetectionRedirectLimit),
("testCountRedirectLimit", testCountRedirectLimit),
]
}
}
38 changes: 35 additions & 3 deletions Tests/AsyncHTTPClientTests/HTTPClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class HTTPClientTests: XCTestCase {
let httpBin = HTTPBin(ssl: false)
let httpsBin = HTTPBin(ssl: true)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew,
configuration: HTTPClient.Configuration(certificateVerification: .none, followRedirects: true))
configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true)))

defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
Expand All @@ -149,7 +149,7 @@ class HTTPClientTests: XCTestCase {
func testHttpHostRedirect() throws {
let httpBin = HTTPBin(ssl: false)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew,
configuration: HTTPClient.Configuration(certificateVerification: .none, followRedirects: true))
configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true)))

defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
Expand Down Expand Up @@ -526,7 +526,7 @@ class HTTPClientTests: XCTestCase {
let httpBin = HTTPBin()
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 5)
let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup),
configuration: HTTPClient.Configuration(followRedirects: true))
configuration: HTTPClient.Configuration(redirectConfiguration: .follow(max: 10, allowCycles: true)))
defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully())
Expand Down Expand Up @@ -568,6 +568,7 @@ class HTTPClientTests: XCTestCase {
func testDecompression() throws {
let httpBin = HTTPBin(compress: true)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: .init(decompression: .enabled(limit: .none)))

defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
XCTAssertNoThrow(try httpBin.shutdown())
Expand Down Expand Up @@ -603,6 +604,7 @@ class HTTPClientTests: XCTestCase {
func testDecompressionLimit() throws {
let httpBin = HTTPBin(compress: true)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: .init(decompression: .enabled(limit: .ratio(10))))

defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
XCTAssertNoThrow(try httpBin.shutdown())
Expand All @@ -626,4 +628,34 @@ class HTTPClientTests: XCTestCase {
XCTFail("Unexptected error: \(error)")
}
}

func testLoopDetectionRedirectLimit() throws {
let httpBin = HTTPBin(ssl: true)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew,
configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 5, allowCycles: false)))

defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
XCTAssertNoThrow(try httpBin.shutdown())
}

XCTAssertThrowsError(try httpClient.get(url: "https://localhost:\(httpBin.port)/redirect/infinite1").wait(), "Should fail with redirect limit") { error in
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.redirectCycleDetected)
}
}

func testCountRedirectLimit() throws {
let httpBin = HTTPBin(ssl: true)
let httpClient = HTTPClient(eventLoopGroupProvider: .createNew,
configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 5, allowCycles: true)))

defer {
XCTAssertNoThrow(try httpClient.syncShutdown())
XCTAssertNoThrow(try httpBin.shutdown())
}

XCTAssertThrowsError(try httpClient.get(url: "https://localhost:\(httpBin.port)/redirect/infinite1").wait(), "Should fail with redirect limit") { error in
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.redirectLimitReached)
}
}
}

0 comments on commit 51dc885

Please sign in to comment.