-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwgsd-client
executable file
·71 lines (58 loc) · 2.16 KB
/
wgsd-client
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python3
# Copyright (C) 2021 Luca Filipozzi <luca.filipozzi@gmail.com>
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
#
# This Source Code Form is "Incompatible With Secondary Licenses", as
# defined by the Mozilla Public License, v. 2.0.
#
#
# This Source Code Form (hereafter 'program') implements a 'client'
# for DNS-based Service Discovery (DNS-SD; RFC 6763 [1]) that extracts
# data from PTR+SRV+TXT records relating to wireguard nodes.
#
#
# usage: wgsd-client --domain example.com
#
#
# [1]: https://tools.ietf.org/html/rfc6763
import click
import dns.resolver
try:
from dns.resolver import resolve as query
except ImportError:
from dns.resolver import query as query
import munch
def _parse_txt_strings(strings):
    """Turn the byte strings of one DNS-SD TXT record into a key/value dict."""
    # Rules applied here follow RFC 6763 section 6.4:
    #   - keys are case-insensitive, so they are folded to lower case
    #   - a string without '=' is a boolean attribute (stored as None)
    #   - a string with an empty key (e.g. '=foo') is silently ignored
    #   - only the first occurrence of a key is honoured
    attrs = {}
    for raw in strings:
        key, sep, value = raw.decode('utf-8').partition('=')
        key = key.lower()
        if not key:
            continue
        attrs.setdefault(key, value if sep else None)
    return attrs


@click.command()
@click.option('--domain', type=str, required=True)
def main(domain):
    """Discover the wireguard nodes advertised under DOMAIN via DNS-SD.

    Prints one line per node: public key, UDP port, target host and
    IPv4 address, separated by single spaces.
    """
    # NOTE: click shows the docstring above as the --help text, so it is
    # kept user-facing; implementation notes live in comments instead.
    #
    # Raises Exception when a node does not have exactly one SRV, TXT and
    # A record, when its TXT record is not 'txtvers=1', or when the TXT
    # record carries no 'pub' value. Errors raised by the resolver itself
    # (e.g. a name that does not exist) propagate unchanged.
    nodes = []

    # for each target in the PTR record
    ptr_answer = query(f'_wireguard._udp.{domain}', 'PTR')
    for ptr in ptr_answer:
        # fetch the SRV record to obtain the UDP port and the host name
        srv_answer = query(ptr.target, 'SRV')
        if len(srv_answer) != 1:
            raise Exception('incorrect number of SRV records')
        srv = srv_answer[0]

        # fetch the TXT record to obtain the public key
        txt_answer = query(srv.target, 'TXT')
        if len(txt_answer) != 1:
            raise Exception('incorrect number of TXT records')
        attrs = _parse_txt_strings(txt_answer[0].strings)
        if attrs.get('txtvers') != '1':
            raise Exception('unknown version')
        if not attrs.get('pub'):
            # fail here rather than with an AttributeError while printing
            raise Exception('missing public key')

        # fetch the A record to obtain the IPv4 address
        a_answer = query(srv.target, 'A')
        if len(a_answer) != 1:
            raise Exception('incorrect number of A records')

        # TXT attributes go in first so that the values taken from the SRV
        # and A records always win: a TXT key named 'port', 'target' or
        # 'address' must not be able to override them.
        node = munch.Munch(attrs)
        node.port = srv.port
        node.target = srv.target
        node.address = a_answer[0].address
        nodes.append(node)

    # print only after every lookup succeeded, so a failure part-way
    # through never leaves partial output behind
    for node in nodes:
        print(f'{node.pub} {node.port} {node.target} {node.address}')
# run the click command only when executed as a script, not when imported
if __name__ == '__main__':
    main()