Skip to content

Commit

Permalink
skip peers with nil endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
jwhited committed Jan 18, 2021
1 parent d9845d7 commit 7eaacc0
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
6 changes: 6 additions & 0 deletions wgsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ func handleSRV(state request.Request, peers []wgtypes.Peer) (int, error) {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
if endpoint == nil {
return nxDomain(state)
}
hostRR := getHostRR(state.Name(), endpoint)
if hostRR == nil {
return nxDomain(state)
Expand Down Expand Up @@ -137,6 +140,9 @@ func handleHostOrTXT(state request.Request, peers []wgtypes.Peer) (int, error) {
if strings.EqualFold(
base32.StdEncoding.EncodeToString(peer.PublicKey[:]), pubKey) {
endpoint := peer.Endpoint
if endpoint == nil {
return nxDomain(state)
}
if state.QType() == dns.TypeA || state.QType() == dns.TypeAAAA {
hostRR := getHostRR(state.Name(), endpoint)
if hostRR == nil {
Expand Down
51 changes: 50 additions & 1 deletion wgsd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ func TestWGSD(t *testing.T) {
}
peer2b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer2.PublicKey[:]))
peer2b64 := base64.StdEncoding.EncodeToString(peer2.PublicKey[:])
key3 := [32]byte{}
key3[0] = 3
peer3Allowed, _ := constructAllowedIPs(t, []string{"10.0.0.5/32", "10.0.0.6/32"})
peer3 := wgtypes.Peer{
Endpoint: nil,
PublicKey: key3,
AllowedIPs: peer3Allowed,
}
peer3b32 := strings.ToLower(base32.StdEncoding.EncodeToString(peer3.PublicKey[:]))
p := &WGSD{
Next: test.ErrorHandler(),
Zones: Zones{
Expand All @@ -91,7 +100,7 @@ func TestWGSD(t *testing.T) {
Name: "wg0",
PublicKey: selfKey,
ListenPort: 51820,
Peers: []wgtypes.Peer{peer1, peer2},
Peers: []wgtypes.Peer{peer1, peer2, peer3},
},
},
},
Expand Down Expand Up @@ -205,6 +214,46 @@ func TestWGSD(t *testing.T) {
Qtype: dns.TypeAAAA,
Rcode: dns.RcodeServerFailure,
},
{
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32),
Qtype: dns.TypeSRV,
Rcode: dns.RcodeNameError,
Ns: []dns.RR{
test.SOA(soa("example.com.").String()),
},
Answer: []dns.RR{},
Extra: []dns.RR{},
},
{
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32),
Qtype: dns.TypeA,
Rcode: dns.RcodeNameError,
Ns: []dns.RR{
test.SOA(soa("example.com.").String()),
},
Answer: []dns.RR{},
Extra: []dns.RR{},
},
{
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32),
Qtype: dns.TypeAAAA,
Rcode: dns.RcodeNameError,
Ns: []dns.RR{
test.SOA(soa("example.com.").String()),
},
Answer: []dns.RR{},
Extra: []dns.RR{},
},
{
Qname: fmt.Sprintf("%s._wireguard._udp.example.com.", peer3b32),
Qtype: dns.TypeTXT,
Rcode: dns.RcodeNameError,
Ns: []dns.RR{
test.SOA(soa("example.com.").String()),
},
Answer: []dns.RR{},
Extra: []dns.RR{},
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%s %s", tc.Qname, dns.TypeToString[tc.Qtype]), func(t *testing.T) {
Expand Down

0 comments on commit 7eaacc0

Please sign in to comment.