1
0
forked from extern/smegmesh

Compare commits

..

98 Commits

Author SHA1 Message Date
Tim Beatham
e26558ce90
Added disclaimer 2024-08-11 13:09:05 +01:00
Tim Beatham
92b57bf610
Update README.md 2024-08-11 13:04:24 +01:00
Tim Beatham
366ffd2535
Create LICENSE 2024-08-11 13:03:17 +01:00
Tim Beatham
0c537395f0 Fixed build errors 2024-08-11 12:58:39 +01:00
Tim Beatham
1c0a559ea1 Changed package name from robin -> cplane 2024-08-11 12:25:52 +01:00
Tim Beatham
c3241c2764 Improving the command help messages 2024-08-11 12:24:15 +01:00
Tim Beatham
83e7f3c004 Updating the README 2024-08-11 12:16:31 +01:00
Tim Beatham
b8585b3a76 main 2024-01-18 16:13:41 +00:00
Tim Beatham
a619838e9e main
- DONE :)
2024-01-18 16:08:41 +00:00
Tim Beatham
b6fe352553 main
- Done
2024-01-18 10:50:59 +00:00
Tim Beatham
664e54b710 main
- Finished for demo
2024-01-18 10:36:55 +00:00
Tim Beatham
ad4d461332 main
- Preparation for demo
2024-01-18 09:59:37 +00:00
Tim Beatham
901674a5e3 main
- Prep for demo
2024-01-17 16:21:49 +00:00
Tim Beatham
915263e49a main
- Ready for presentation
2024-01-17 14:52:09 +00:00
Tim Beatham
41d41694a6 main
- Preparing for demo
2024-01-17 13:50:18 +00:00
Tim Beatham
3f82ef9cd7 Bugfix 2024-01-16 16:59:07 +00:00
Tim Beatham
7e6f2563c7 fixed bug for demo 2024-01-16 16:25:32 +00:00
Tim Beatham
c91e6e7f68 Added missing configuration files 2024-01-16 14:59:05 +00:00
Tim Beatham
ed525c045a
Merge pull request #85 from tim-beatham/81-separate-synchronisation-into-independent-processes
Submitting
2024-01-05 18:22:48 +00:00
Tim Beatham
9a30f4d5cb Submitting 2024-01-05 18:22:05 +00:00
Tim Beatham
b294f116a2
Merge pull request #84 from tim-beatham/81-separate-synchronisation-into-independent-processes
81-seperate-processes
2024-01-05 17:00:29 +00:00
Tim Beatham
f647c1b806 81-seperate-processes
Prep for submission
2024-01-05 16:59:02 +00:00
Tim Beatham
0136e44b36
Merge pull request #83 from tim-beatham/81-separate-synchronisation-into-independent-processes
81 separate synchronisation into independent processes
2024-01-05 13:05:01 +00:00
Tim Beatham
a55dadf088 81-seperate-synchronisation-into-independent-procs
- Neaten code
2024-01-05 12:59:13 +00:00
Tim Beatham
0ec5156e59 81-procs
- fixed issue where route not deleting if mesh only one
2024-01-05 00:14:25 +00:00
Tim Beatham
2b73d241b6 81-serparate-procs
- nil dereference again
2024-01-04 22:29:30 +00:00
Tim Beatham
69b1790bb6 81-processes
- issue with client client traversal
2024-01-04 22:08:14 +00:00
Tim Beatham
4a92743880 81-seperate-sync
- build error
2024-01-04 21:48:54 +00:00
Tim Beatham
038393052c 81-seperate-synchronisation-into-independent-proc
- build error
2024-01-04 21:47:29 +00:00
Tim Beatham
5efff2314b 81-separate-synchronisation-into-independent-process
- nil dereference when no joins
2024-01-04 21:45:28 +00:00
Tim Beatham
1f8d229076 81-seperate-synchronisation-into-independent-process
- nil dereference due to concurrency issues (the method shouldn't be
  concurrent)
2024-01-04 21:16:33 +00:00
Tim Beatham
a0e7a4a644 81-seperateprocesses-into-independent-processes
- Fixed errors
2024-01-04 13:15:29 +00:00
Tim Beatham
f9b8b85ec3 81-seperate-synchronisation
- Removed authentication.proto
2024-01-04 13:12:33 +00:00
Tim Beatham
59d8ae4334 81-seperate-synchronisation
- More code comments
2024-01-04 13:12:07 +00:00
Tim Beatham
02dfd73e08 81-seperate-synchronisation-into-independent
- Separated synchronisation calls into independent processes
- Commented code for submission
2024-01-04 13:10:08 +00:00
Tim Beatham
9818645299
Merge pull request #82 from tim-beatham/bugfix-node-not-leving
bugfix-node-not-leaving
2024-01-04 00:24:58 +00:00
Tim Beatham
1f0914e2df bugfix-node-not-leaving
- Add lock when perform synchronisation on concurrent access
2024-01-04 00:23:20 +00:00
Tim Beatham
efb40d65de
Merge pull request #80 from tim-beatham/bugfix-node-not-leving
main
2024-01-02 20:32:09 +00:00
Tim Beatham
27e00196cd main
- Not waiting in the waitgroup
2024-01-02 20:31:24 +00:00
Tim Beatham
4543205703
Merge pull request #79 from tim-beatham/bugfix-node-not-leving
main
2024-01-02 20:21:27 +00:00
Tim Beatham
dea6f1a22d main
- error in code invalid check for nil
2024-01-02 20:19:34 +00:00
Tim Beatham
4d19da6727
Merge pull request #78 from tim-beatham/bugfix-node-not-leving
main
2024-01-02 20:12:10 +00:00
Tim Beatham
913de57568 main
- Fixed bug
2024-01-02 20:11:11 +00:00
Tim Beatham
8a5673e303
Merge pull request #77 from tim-beatham/bugfix-node-not-leving
bugfix node not leaving
2024-01-02 19:43:04 +00:00
Tim Beatham
ce829114b1 bugfix
- on synchornisation node is not leaving mesh
2024-01-02 19:41:20 +00:00
Tim Beatham
05cc287e31
Merge pull request #76 from tim-beatham/74-perform-dad
- Fixing DNS error
2024-01-02 00:16:45 +00:00
Tim Beatham
cd844ff46e - Fixing DNS error 2024-01-02 00:15:23 +00:00
Tim Beatham
4b9406a920
Merge pull request #75 from tim-beatham/74-perform-dad
74-perform-dad
2024-01-02 00:14:37 +00:00
Tim Beatham
d0b1913796 74-perform-dad
- Fixing nil pointer dereference
2024-01-02 00:13:04 +00:00
Tim Beatham
90cfe820d2 - Fixing errors with stale paths 2024-01-02 00:09:31 +00:00
Tim Beatham
8a49809855 74-perform-dad
- Adding go.sum to fix errors
2024-01-01 23:59:04 +00:00
Tim Beatham
dbc18bddc6 74-perform-dad
- Performing DAD to check if IPv6 address present before adding
  outselves to mesh
- Changing name from wgmesh to smegmesh
2024-01-01 23:55:50 +00:00
Tim Beatham
14f335af74
Merge pull request #73 from tim-beatham/72-pull-rate-in-configuration
72 pull rate in configuration
2023-12-31 14:26:34 +00:00
Tim Beatham
36e82dba47 72-pull-rate-in-configuration
- Refactored pull rate into the configuration
- code freeze so no more code changes
2023-12-31 14:25:06 +00:00
Tim Beatham
3cc87bc252 72-pull-rate-in-configuration
- Updated examples
2023-12-31 12:47:45 +00:00
Tim Beatham
a9ed7c0a20 72-pull-rate-in-configuration
- Removing libp2p reference
2023-12-31 12:47:45 +00:00
Tim Beatham
fd29af73e3 72-pull-rate-in-configuration
- Added pull rate to configuration (finally) so this can
be modified by an administrator.
2023-12-31 12:47:45 +00:00
Tim Beatham
9e1058e0f2 72-pull-rate-in-configuration
- Added the pull rate to the configuration file
2023-12-31 12:47:45 +00:00
Tim Beatham
c29eb197f3
Merge pull request #71 from tim-beatham/66-ipv6-address-not-conforming-to-spec
66 ipv6 address not conforming to spec
2023-12-30 22:26:53 +00:00
Tim Beatham
1a9d9d61ad 66-ipv6-address-not-conforming-to-spec
- Missing commit
2023-12-30 22:26:08 +00:00
Tim Beatham
6954608c32 66-ipv6-address-not-confirming-to-spec
- UUID is not random just a name generator needs changing to shortuuid
- When in multiple meshes there is no wait group
2023-12-30 22:24:43 +00:00
Tim Beatham
2e6aed6f93 main
- Fixing issue with nil pointer de-reference due to bad design of mesh
  manager.
- Going forward all references to GetSelf should be depracated. It
  introduces a race condition when leaving a mesh network
2023-12-30 00:44:57 +00:00
Tim Beatham
b0893a0b8e
Merge pull request #69 from tim-beatham/60-unit-test-crdt-data-store
60-unit-test-crdt-data-store
2023-12-29 22:06:20 +00:00
Tim Beatham
e7d6055fa3 60-unit-test-crdt-data-store
Provided unit tests for datastore.go
And fixed unit tets failing by different way of providing CA
2023-12-29 22:05:05 +00:00
Tim Beatham
e0f3f116b9 main
- Stale serverConfig entry causing certificate authorities
to not become authorised
2023-12-29 19:54:08 +00:00
Tim Beatham
352648b7cb main
- Fixed problem where connection not removed on error
2023-12-29 11:12:40 +00:00
Tim Beatham
2d5df25b1d main
- If deadline exceeded error remove connection from
connection manager
2023-12-29 01:29:11 +00:00
Tim Beatham
cabe173831 main
Adding retry parameter
2023-12-29 01:10:26 +00:00
Tim Beatham
d2c8a52ec6 main
- Adding retry policy for mobility
2023-12-29 00:58:43 +00:00
Tim Beatham
bf53108384 main
- Bugfix, fix consistent hash problem where
if failure happens then causes panic
2023-12-28 23:24:38 +00:00
Tim Beatham
77aac5534b main
- Bugfix in client where "-" was attempted to be parsed as a UDP addr
2023-12-28 17:46:04 +00:00
Tim Beatham
58439fcd56 main
- Bugfix when keepalivewg is not set causes segmentation fault
- give keepalive a default value of 0 if not set
2023-12-28 17:32:54 +00:00
Tim Beatham
311a15363a
Merge pull request #67 from tim-beatham/66-improve-graph-dot-tool
66 improve graph dot tool
2023-12-25 01:26:15 +00:00
Tim Beatham
255d3c8b39 66-improve-graph-dot-tool
- Showing services a node provides
- Showing all meshes not just one
- Showing the default route
2023-12-25 01:25:20 +00:00
Tim Beatham
41899c5831 66-improve-graph-dot-tool
Improving the graph dot tool so that it shows all
meshes
2023-12-25 01:10:11 +00:00
Tim Beatham
fe4ca66ff6
Merge pull request #65 from tim-beatham/64-2p-set-unit-test
64 2p set unit test
2023-12-22 23:58:59 +00:00
Tim Beatham
0b91ba744a 61-improve-unit-test-coverage
- Provided unit tests for g_map and 2p_map
2023-12-22 23:57:10 +00:00
Tim Beatham
67483c2a90 64-unit-test-two-phase-set
Provide unit tests for two phase set to make it more
transparent what exactly they are doing.
2023-12-22 23:57:10 +00:00
Tim Beatham
af26e81bd3
Merge pull request #63 from tim-beatham/61-improve-unit-testing-coverage
61-improve-unit-testing-coverage
2023-12-22 21:52:46 +00:00
Tim Beatham
0cc3141b58 61-improve-unit-testing-coverage
- Added missing files to commit
2023-12-22 21:49:47 +00:00
Tim Beatham
186acbe915
Merge pull request #62 from tim-beatham/61-improve-unit-testing-coverage
61-improve-unit-testing-coverage
2023-12-22 21:49:06 +00:00
Tim Beatham
ceb43a1db1 61-improve-unit-testing-coverage
- Got unit tests passing
- Improved manager unit tests
2023-12-22 21:47:56 +00:00
Tim Beatham
bed59f120f
Merge pull request #60 from tim-beatham/59-error-when-peer-not-selected
59-error-when-peer-not-selected
2023-12-22 19:12:30 +00:00
Tim Beatham
8aab4e99d8 59-error-when-peer-not-selected
In the CLI when the peer is not selected
as the type throwing an error stating
either client or peer must be selected
2023-12-22 19:08:20 +00:00
Tim Beatham
cf4be1ccab
Merge pull request #58 from tim-beatham/bugfix-pull-only
Bugfix pull only
2023-12-22 18:49:09 +00:00
Tim Beatham
6ed32f3a79 bugfix-push-pull
Organised groups as a tree so that there
isn't a limit to dissemination
2023-12-19 00:50:17 +00:00
Tim Beatham
b6199892f0 bugfix-pull-only
Bugfix with inter-cluster communication pull not working
2023-12-18 22:17:46 +00:00
Tim Beatham
ad22f04b0d bugfix-pull-only
After certain period of time if no changes have
occurred then pull
2023-12-18 20:45:56 +00:00
Tim Beatham
092d9a4af5 checking-latency-for-pull-only 2023-12-17 09:44:32 +00:00
Tim Beatham
19abf712a6 Fixing bug with nodes being removed 2023-12-12 12:45:41 +00:00
Tim Beatham
b296e1f45a
Merge pull request #57 from tim-beatham/55-cli-option-for-peer-type
55-cli-optionifor-peer-type
2023-12-12 12:00:42 +00:00
Tim Beatham
2dc89d171b 55-cli-optionifor-peer-type
- Ability to specify WireGuard keepalive in the CLI formatter
- Ability to specify publicly routeable endpoint
- Ability to specify whether to advetise routes into the mesh,
and whether to advertise default routes.
2023-12-12 11:58:47 +00:00
Tim Beatham
13bea10638 main - bugfix
- Nodes not being removed when deleted because when node gossips again
  it is readded.
- Keep track of highest vector clock we have removed and used this as a
  mark for determining if something is stale.
2023-12-11 11:09:02 +00:00
Tim Beatham
3222d7e388 main - adding WireGuard stats to JSON objects
- Adding WireGuard stats through to IPC calls so that they can be used
by the API
2023-12-11 09:55:25 +00:00
Tim Beatham
1789d203f6 main - fix default routing being deleted
Default route keeps fluctuating on configuration
update.
2023-12-10 23:35:00 +00:00
Tim Beatham
a5074a536e main - BUGFIX
- segfault BUGFIX
2023-12-10 22:31:24 +00:00
Tim Beatham
acb90a5679 main - go.sum should be tracked into the git
- go.sum should be contained in the git history
2023-12-10 22:11:09 +00:00
Tim Beatham
27ec23f133
Merge pull request #54 from tim-beatham/53-run-commands-pre-up-and-post-down
53-run-commands-pre-up-and-post-down
2023-12-10 19:22:59 +00:00
95 changed files with 5649 additions and 2735 deletions

3
.gitmodules vendored
View File

@ -1,3 +0,0 @@
[submodule "smegmesh-web"]
path = smegmesh-web
url = git@github.com:tim-beatham/smegmesh-web.git

674
LICENSE Normal file
View File

@ -0,0 +1,674 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

View File

@ -1,2 +1,97 @@
# wgmesh # smegmesh
WireGuard VPN Mesh Management
## Disclaimer
Submitted to fill the requirements of Msci (Hons) Computer Science at the School of Computer Science, University of St Andrews.
## License
This repository is licensed under the MIT License. See the [LICENSE](./LICENSE) file for more details.
## Overview
Distributed WireGuard mesh management. This tool helps to configure WireGuard
networks in a mesh topology such that there is no single point of failure.
The tool aims to set-up mesh networks with minimal knowledge and
configuration of WireGuard.
The idea being that a node can take up one of two roles in the network, a
peer or a client. A peer is publicly accessible and must have IPv6 forwarding
enabled. Peer's responsibility is routing traffic on behalf of clients
associated with it.
Whereas, a client hides behind a private endpoint in which all packets are
routed through the peer. A client must enable the flat `keepAliveWg` to
ensure that its associated peer learns about any NAT mappings that change.
IPv6 is used in the overlay to make use of the larger address space.
A node hashes it's WireGuard public key to create an identifier
(the last 64-bits of the IPv6 address) and the mesh-id is hashed into
the first 64-bits of the IPv6 address to create the locator.
A node (both client and a peer) can be in multiple meshes at the same
time. In which case the node can optionally choose to act as a bridge
and forward packets between the two meshes. Through this it is possible
to define complex topologies. To route between meshes multiple hops away
a simple link-state protocol is adopted (similar to RIP) in which the
path length (number of meshes) is used to determine the shortest path.
Redundant routing is possible to create multiple exit points to the same
mesh network. In which case consistent hashing is performed to split traffic
between the exit points.
## Message Dissemination
A variant of the gossip protocol is used for message dissemination. Each peer
in the network is ordered lexicographically ordered by their public key.
The node with the lexicographically lowest public key is used as the leader
of the mesh. Every `heartBeatInterval` disseminates a refresh message
throughout the entirety of the group in order to prune nodes that may
have prematurely died.
If after `3 * heartBeatInterval` a node has not received a dissemination
message then the node prunes the leader and expects one from the next
lexicographically lowest public key.
To 'merge' updates and reconcile any conflicts a Conflict Free Replicated
Data Type (CRDT) is implemented. Consisting of an add and remove set.
Where a node is in the group if it is in the add set and there is either
no entry in the remove set or the timestamp in the remove set has a lower
vector clock value.
## Performance
This prototype has been tested to a scale of 3000 peers in the network.
Furthermore, the fault-tolerance has been tested to a scale 3000 nodes
to the order of 20 seconds for the entire network and 12 seconds
for the 99 percentile.
## Installation
To build the project do: `go build -v ./...`. A Docker file is provided
to get started.
To build with the Dockerfile:
`docker build -t smegmesh-base ./`
Then run an example topology in the examples folder. For example:
`cd examples/simple && docker-compose up -d`
## Tools
### Smegd
Smegmesh requires the daemon process to be running (smegd) which also takes
a configuration.yaml file. An example yaml configuration file is provided in
examples/simple/shared/configuration.
### Smegctl
Smegctl is a CLI tool to create, join, visualise and administer networks.
### Api
An api is provided to invoke functions to create, join, visualise and administer
networks. This could be used to create an application that allows a user
to configure the networks.
### Dns
A dns server is provided to resolve an alias into an IPv6 address.

View File

@ -3,7 +3,7 @@ package main
import ( import (
"log" "log"
"github.com/tim-beatham/wgmesh/pkg/api" "github.com/tim-beatham/smegmesh/pkg/api"
) )
func main() { func main() {

View File

@ -3,7 +3,7 @@ package main
import ( import (
"log" "log"
smegdns "github.com/tim-beatham/wgmesh/pkg/dns" smegdns "github.com/tim-beatham/smegmesh/pkg/dns"
) )
func main() { func main() {

407
cmd/smegctl/main.go Normal file
View File

@ -0,0 +1,407 @@
package main
import (
"fmt"
ipcRpc "net/rpc"
"os"
"github.com/akamensky/argparse"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
graph "github.com/tim-beatham/smegmesh/pkg/dot"
"github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/smegmesh/pkg/log"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct {
Client *ipcRpc.Client
Endpoint string
WgArgs ipc.WireGuardArgs
AdvertiseRoutes bool
AdvertiseDefault bool
}
func createMesh(client *ipc.SmegmeshIpc, args *ipc.NewMeshArgs) {
var reply string
err := client.CreateMesh(args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func listMeshes(client *ipc.SmegmeshIpc) {
reply := new(ipc.ListMeshReply)
err := client.ListMeshes(reply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
return
}
for _, meshId := range reply.Meshes {
fmt.Println(meshId)
}
}
func joinMesh(client *ipc.SmegmeshIpc, args ipc.JoinMeshArgs) {
var reply string
err := client.JoinMesh(args, &reply)
if err != nil {
fmt.Print(err.Error())
}
fmt.Println(reply)
}
func leaveMesh(client *ipc.SmegmeshIpc, meshId string) {
var reply string
err := client.LeaveMesh(meshId, &reply)
if err != nil {
fmt.Print(err.Error())
return
}
fmt.Println(reply)
}
func getGraph(client *ipc.SmegmeshIpc) {
listMeshesReply := new(ipc.ListMeshReply)
err := client.ListMeshes(listMeshesReply)
if err != nil {
fmt.Print(err.Error())
return
}
meshes := make(map[string][]ctrlserver.MeshNode)
for _, meshId := range listMeshesReply.Meshes {
var meshReply ipc.GetMeshReply
err := client.GetMesh(meshId, &meshReply)
if err != nil {
fmt.Println(err.Error())
return
}
meshes[meshId] = meshReply.Nodes
}
dotGenerator := graph.NewMeshGraphConverter(meshes)
dot, err := dotGenerator.Generate()
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(dot)
}
func queryMesh(client *ipc.SmegmeshIpc, meshId, query string) {
var reply string
args := ipc.QueryMesh{
MeshId: meshId,
Query: query,
}
err := client.Query(args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func putDescription(client *ipc.SmegmeshIpc, meshId, description string) {
var reply string
err := client.PutDescription(ipc.PutDescriptionArgs{
MeshId: meshId,
Description: description,
}, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
// putAlias: puts an alias for the node
func putAlias(client *ipc.SmegmeshIpc, meshid, alias string) {
var reply string
err := client.PutAlias(ipc.PutAliasArgs{
MeshId: meshid,
Alias: alias,
}, &reply)
if err != nil {
fmt.Print(err.Error())
return
}
fmt.Println(reply)
}
func setService(client *ipc.SmegmeshIpc, meshId, service, value string) {
var reply string
err := client.PutService(ipc.PutServiceArgs{
MeshId: meshId,
Service: service,
Value: value,
}, &reply)
if err != nil {
fmt.Print(err.Error())
return
}
fmt.Println(reply)
}
func deleteService(client *ipc.SmegmeshIpc, meshId, service string) {
var reply string
err := client.DeleteService(ipc.DeleteServiceArgs{
MeshId: meshId,
Service: service,
}, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() {
parser := argparse.NewParser("smgctl",
"smegctl Manipulate WireGuard mesh networks")
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network")
getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format")
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath")
putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node")
putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node")
setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements")
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{
Default: 0,
Help: "WireGuard port to use to the interface. A default of 0 uses an unused ephmeral port.",
})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{
Help: "Publicly routeable endpoint to advertise within the mesh",
})
var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{
Help: "Role in the mesh network. A peer is publicly route-able, whereas a client sits behind a private endpoint",
})
var newMeshKeepAliveWg *int = newMeshCmd.Int("k", "KeepAliveWg", &argparse.Options{
Default: 0,
Help: "WireGuard KeepAlive value for NAT traversal and firewall hole-punching",
})
var newMeshAdvertiseRoutes *bool = newMeshCmd.Flag("a", "advertise", &argparse.Options{
Help: "Advertise routes to other mesh network into the mesh",
})
var newMeshAdvertiseDefaults *bool = newMeshCmd.Flag("d", "defaults", &argparse.Options{
Help: "Advertise ::/0 into the mesh network",
})
var joinMeshId *string = joinMeshCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{
Required: true,
Help: "IP address of the bootstrapping node to join through",
})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{
Help: "Publicly routeable endpoint to advertise within the mesh",
})
var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{
Help: "Role in the mesh network. A value of peer means that the node is publicly route-able acting as a router " +
"for clients to route packets through. A client sits behind a private endpoint and routes traffic through a single " +
"endpoint",
})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{
Default: 0,
Help: "WireGuard port to use to the interface. A default of 0 uses an unused ephmeral port.",
})
var joinMeshKeepAliveWg *int = joinMeshCmd.Int("k", "KeepAliveWg", &argparse.Options{
Default: 0,
Help: "WireGuard KeepAlive value for NAT traversal and firewall ho;lepunching",
})
var joinMeshAdvertiseRoutes *bool = joinMeshCmd.Flag("a", "advertise", &argparse.Options{
Help: "Advertise routes to other mesh network into the mesh",
})
var joinMeshAdvertiseDefaults *bool = joinMeshCmd.Flag("d", "defaults", &argparse.Options{
Help: "Advertise ::/0 into the mesh network",
})
var leaveMeshMeshId *string = leaveMeshCmd.String("m", "mesh", &argparse.Options{
Required: true,
Help: "MeshID of the mesh to leave",
})
var queryMeshMeshId *string = queryMeshCmd.String("m", "mesh", &argparse.Options{
Required: true,
Help: "MeshID of the mesh to query",
})
var queryMeshQuery *string = queryMeshCmd.String("q", "query", &argparse.Options{
Required: true,
Help: "JMESPath Query Of The Mesh Network To Query",
})
var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{
Required: true,
Help: "Description of the node in the mesh",
})
var descriptionMeshId *string = putDescriptionCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var aliasMeshId *string = putAliasCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{
Required: true,
Help: "Alias of the node to set can be used in DNS to lookup an IP address",
})
var serviceKey *string = setServiceCmd.String("s", "service", &argparse.Options{
Required: true,
Help: "Key of the service to advertise in the mesh network",
})
var serviceValue *string = setServiceCmd.String("v", "value", &argparse.Options{
Required: true,
Help: "Value of the service to advertise in the mesh network",
})
var serviceMeshId *string = setServiceCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{
Required: true,
Help: "Key of the service to remove",
})
var deleteServiceMeshid *string = deleteServiceCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
err := parser.Parse(os.Args)
if err != nil {
fmt.Print(parser.Usage(err))
return
}
client, err := ipc.NewClientIpc()
if err != nil {
panic(err)
}
if newMeshCmd.Happened() {
args := &ipc.NewMeshArgs{
WgArgs: ipc.WireGuardArgs{
Endpoint: *newMeshEndpoint,
Role: *newMeshRole,
WgPort: *newMeshPort,
KeepAliveWg: *newMeshKeepAliveWg,
AdvertiseDefaultRoute: *newMeshAdvertiseDefaults,
AdvertiseRoutes: *newMeshAdvertiseRoutes,
},
}
createMesh(client, args)
}
if listMeshCmd.Happened() {
listMeshes(client)
}
if joinMeshCmd.Happened() {
args := ipc.JoinMeshArgs{
IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId,
WgArgs: ipc.WireGuardArgs{
Endpoint: *joinMeshEndpoint,
Role: *joinMeshRole,
WgPort: *joinMeshPort,
KeepAliveWg: *joinMeshKeepAliveWg,
AdvertiseDefaultRoute: *joinMeshAdvertiseDefaults,
AdvertiseRoutes: *joinMeshAdvertiseRoutes,
},
}
joinMesh(client, args)
}
if getGraphCmd.Happened() {
getGraph(client)
}
if leaveMeshCmd.Happened() {
leaveMesh(client, *leaveMeshMeshId)
}
if queryMeshCmd.Happened() {
queryMesh(client, *queryMeshMeshId, *queryMeshQuery)
}
if putDescriptionCmd.Happened() {
putDescription(client, *descriptionMeshId, *description)
}
if putAliasCmd.Happened() {
putAlias(client, *aliasMeshId, *alias)
}
if setServiceCmd.Happened() {
setService(client, *serviceMeshId, *serviceKey, *serviceValue)
}
if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceMeshid, *deleteServiceKey)
}
}

View File

@ -1,34 +1,34 @@
package main package main
import ( import (
"net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"os/signal" "os/signal"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
ctrlserver "github.com/tim-beatham/wgmesh/pkg/ctrlserver" robin "github.com/tim-beatham/smegmesh/pkg/cplane"
"github.com/tim-beatham/wgmesh/pkg/ipc" ctrlserver "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/robin" "github.com/tim-beatham/smegmesh/pkg/sync"
"github.com/tim-beatham/wgmesh/pkg/sync"
timer "github.com/tim-beatham/wgmesh/pkg/timers"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )
func main() { func main() {
if len(os.Args) != 2 { if len(os.Args) != 2 {
logging.Log.WriteErrorf("Did not provide configuration") logging.Log.WriteErrorf("Did not provide configuration")
return return
} }
conf, err := conf.ParseDaemonConfiguration(os.Args[1]) configuration, err := conf.ParseDaemonConfiguration(os.Args[1])
if err != nil { if err != nil {
logging.Log.WriteErrorf("Could not parse configuration: %s", err.Error()) logging.Log.WriteErrorf("Could not parse configuration: %s", err.Error())
return return
} }
logging.SetLogger(logging.NewLogrusLogger(configuration.LogLevel))
client, err := wgctrl.New() client, err := wgctrl.New()
if err != nil { if err != nil {
@ -36,34 +36,24 @@ func main() {
return return
} }
if conf.Profile {
go func() {
http.ListenAndServe("localhost:6060", nil)
}()
}
var robinRpc robin.WgRpc var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler var robinIpc robin.IpcHandler
var syncProvider sync.SyncServiceImpl var syncProvider sync.SyncServiceImpl
var syncRequester sync.SyncRequester
var syncer sync.Syncer
ctrlServerParams := ctrlserver.NewCtrlServerParams{ ctrlServerParams := ctrlserver.NewCtrlServerParams{
Conf: conf, Conf: configuration,
CtrlProvider: &robinRpc, CtrlProvider: &robinRpc,
SyncProvider: &syncProvider, SyncProvider: &syncProvider,
Client: client, Client: client,
OnDelete: func(mp mesh.MeshProvider) {
syncer.SyncMeshes()
},
} }
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
syncProvider.Server = ctrlServer
syncRequester = sync.NewSyncRequester(ctrlServer) if err != nil {
syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester) panic(err)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, syncer) }
keepAlive := timer.NewTimestampScheduler(ctrlServer)
syncProvider.MeshManager = ctrlServer.MeshManager
robinIpcParams := robin.RobinIpcParams{ robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer, CtrlServer: ctrlServer,
@ -77,16 +67,11 @@ func main() {
return return
} }
logging.Log.WriteInfof("Running IPC Handler") logging.Log.WriteInfof("running ipc handler")
go ipc.RunIpcHandler(&robinIpc) go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run()
go keepAlive.Run()
closeResources := func() { closeResources := func() {
logging.Log.WriteInfof("Closing resources") logging.Log.WriteInfof("closing resources")
syncScheduler.Stop()
keepAlive.Stop()
ctrlServer.Close() ctrlServer.Close()
client.Close() client.Close()
} }

View File

@ -1,331 +0,0 @@
package main
import (
"fmt"
ipcRpc "net/rpc"
"os"
"github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct {
Client *ipcRpc.Client
WgPort int
Endpoint string
Role string
}
func createMesh(args *CreateMeshParams) string {
var reply string
newMeshParams := ipc.NewMeshArgs{
WgPort: args.WgPort,
Endpoint: args.Endpoint,
Role: args.Role,
}
err := args.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply)
if err != nil {
return err.Error()
}
return reply
}
func listMeshes(client *ipcRpc.Client) {
reply := new(ipc.ListMeshReply)
err := client.Call("IpcHandler.ListMeshes", "", &reply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
return
}
for _, meshId := range reply.Meshes {
fmt.Println(meshId)
}
}
type JoinMeshParams struct {
Client *ipcRpc.Client
MeshId string
IpAddress string
IfName string
WgPort int
Endpoint string
Role string
}
func joinMesh(params *JoinMeshParams) string {
var reply string
args := ipc.JoinMeshArgs{
MeshId: params.MeshId,
IpAdress: params.IpAddress,
Port: params.WgPort,
Role: params.Role,
}
err := params.Client.Call("IpcHandler.JoinMesh", &args, &reply)
if err != nil {
return err.Error()
}
return reply
}
func leaveMesh(client *ipcRpc.Client, meshId string) {
var reply string
err := client.Call("IpcHandler.LeaveMesh", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func enableInterface(client *ipcRpc.Client, meshId string) {
var reply string
err := client.Call("IpcHandler.EnableInterface", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func getGraph(client *ipcRpc.Client, meshId string) {
var reply string
err := client.Call("IpcHandler.GetDOT", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func queryMesh(client *ipcRpc.Client, meshId, query string) {
var reply string
err := client.Call("IpcHandler.Query", &ipc.QueryMesh{MeshId: meshId, Query: query}, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
// putDescription: puts updates the description about the node to the meshes
func putDescription(client *ipcRpc.Client, description string) {
var reply string
err := client.Call("IpcHandler.PutDescription", &description, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
// putAlias: puts an alias for the node
func putAlias(client *ipcRpc.Client, alias string) {
var reply string
err := client.Call("IpcHandler.PutAlias", &alias, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func setService(client *ipcRpc.Client, service, value string) {
var reply string
serviceArgs := &ipc.PutServiceArgs{
Service: service,
Value: value,
}
err := client.Call("IpcHandler.PutService", serviceArgs, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func deleteService(client *ipcRpc.Client, service string) {
var reply string
err := client.Call("IpcHandler.PutService", &service, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func getNode(client *ipcRpc.Client, nodeId, meshId string) {
var reply string
args := &ipc.GetNodeArgs{
NodeId: nodeId,
MeshId: meshId,
}
err := client.Call("IpcHandler.GetNode", &args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() {
parser := argparse.NewParser("wg-mesh",
"wg-mesh Manipulate WireGuard meshes")
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network")
enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface")
getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format")
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath")
putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node")
putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node")
setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements")
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{})
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true})
var leaveMeshMeshId *string = leaveMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var queryMeshMeshId *string = queryMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var queryMeshQuery *string = queryMeshCmd.String("q", "query", &argparse.Options{Required: true})
var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{Required: true})
var serviceKey *string = setServiceCmd.String("s", "service", &argparse.Options{Required: true})
var serviceValue *string = setServiceCmd.String("v", "value", &argparse.Options{Required: true})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{Required: true})
var getNodeNodeId *string = getNodeCmd.String("n", "nodeid", &argparse.Options{Required: true})
var getNodeMeshId *string = getNodeCmd.String("m", "meshid", &argparse.Options{Required: true})
err := parser.Parse(os.Args)
if err != nil {
fmt.Print(parser.Usage(err))
return
}
client, err := ipcRpc.DialHTTP("unix", SockAddr)
if err != nil {
fmt.Println(err.Error())
return
}
if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{
Client: client,
WgPort: *newMeshPort,
Endpoint: *newMeshEndpoint,
Role: *newMeshRole,
}))
}
if listMeshCmd.Happened() {
listMeshes(client)
}
if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{
Client: client,
WgPort: *joinMeshPort,
IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId,
Endpoint: *joinMeshEndpoint,
Role: *joinMeshRole,
}))
}
if getGraphCmd.Happened() {
getGraph(client, *getGraphMeshId)
}
if enableInterfaceCmd.Happened() {
enableInterface(client, *enableInterfaceMeshId)
}
if leaveMeshCmd.Happened() {
leaveMesh(client, *leaveMeshMeshId)
}
if queryMeshCmd.Happened() {
queryMesh(client, *queryMeshMeshId, *queryMeshQuery)
}
if putDescriptionCmd.Happened() {
putDescription(client, *description)
}
if putAliasCmd.Happened() {
putAlias(client, *alias)
}
if setServiceCmd.Happened() {
setService(client, *serviceKey, *serviceValue)
}
if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceKey)
}
if getNodeCmd.Happened() {
getNode(client, *getNodeNodeId, *getNodeMeshId)
}
}

View File

@ -10,5 +10,5 @@ syncRate: 1
interClusterChance: 0.15 interClusterChance: 0.15
branchRate: 3 branchRate: 3
infectionCount: 3 infectionCount: 3
keepAliveTime: 10 heartBeatTime: 10
pruneTime: 20 pruneTime: 20

32
conf/client.yaml Normal file
View File

@ -0,0 +1,32 @@
# Paths to the certificates modify
# if not running from Smegmesh
certificatePath: "./cert/cert.pem"
privateKeyPath: "./cert/priv.pem"
caCertificatePath: "./cert/cacert.pem"
skipCertVerification: true
# timeout is the configured grpc timeout
timeout: 5
# gRPC port to run the solution
gRPCPort: 4000
# stubWg: whether to install WireGuard configurations
# if true just tests the control plane
stubWg: false
heartbeatInterval: 60
branch: 3
pullInterval: 20
infectionCount: 3
interClusterChance: 0.15
syncInterval: 10
clusterSize: 64
logLevel: "info"
baseConfiguration:
# ipDiscovery: specifies how to find your IP address
ipDiscovery: "outgoing"
# alternative to ipDiscovery specify an actual endpoint yourself with publicEndpoint: "xxxx"
# role is the role that you are playing (peer | client)
# peers can only bootstrap meshes
role: "client"
# advertise meshes to other meshes
advertiseRoute: true
# advertise default routes
advertiseDefaults: true

33
conf/peer.yaml Normal file
View File

@ -0,0 +1,33 @@
# Paths to the certificates modify
# if not running from Smegmesh
certificatePath: "./cert/cert.pem"
privateKeyPath: "./cert/priv.pem"
caCertificatePath: "./cert/cacert.pem"
skipCertVerification: true
# timeout is the configured grpc timeout
timeout: 5
# gRPC port to run the solution
gRPCPort: 4000
# stubWg: whether to install WireGuard configurations
# if true just tests the control plane
stubWg: false
heartbeatInterval: 60
branch: 3
pullInterval: 20
infectionCount: 3
interClusterChance: 0.15
syncInterval: 2
clusterSize: 64
logLevel: "info"
baseConfiguration:
# ipDiscovery: specifies how to find your IP address
ipDiscovery: "outgoing"
# alternative to ipDiscovery specify an actual endpoint yourself with publicEndpoint: "xxxx"
# role is the role that you are playing (peer | client)
# peers can only bootstrap meshes
role: "peer"
# advertise meshes to other meshes
advertiseRoute: true
# advertise default routes
advertiseDefaults: true

View File

@ -1,95 +0,0 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
net-2:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.155.0/17
services:
wg-1:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-2:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-3:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-4:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
sysctls:
- net.ipv6.conf.all.forwarding=1
networks:
- net-1
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-5:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-6:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-7:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"

View File

@ -1,14 +0,0 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
pruneTime: 20

View File

@ -1,14 +1,14 @@
version: '3' version: '3'
networks: networks:
net-1: net-1:
driver: bridge enable_ipv6: true
ipam: ipam:
driver: default driver: default
config: config:
- subnet: 10.89.0.0/17 - subnet: 2001:db8::/64
services: services:
wg-1: wg-1:
image: wg-mesh-base:latest image: smegmesh-base:latest
cap_add: cap_add:
- NET_ADMIN - NET_ADMIN
- NET_RAW - NET_RAW
@ -17,9 +17,12 @@ services:
- net-1 - net-1
volumes: volumes:
- ./shared:/shared - ./shared:/shared
command: "wgmeshd /shared/configuration.yaml" command: "smegd /shared/configuration.yaml"
sysctls:
- net.ipv6.conf.all.forwarding=1
- net.ipv6.conf.all.disable_ipv6=0
wg-2: wg-2:
image: wg-mesh-base:latest image: smegmesh-base:latest
cap_add: cap_add:
- NET_ADMIN - NET_ADMIN
- NET_RAW - NET_RAW
@ -28,9 +31,12 @@ services:
- net-1 - net-1
volumes: volumes:
- ./shared:/shared - ./shared:/shared
command: "wgmeshd /shared/configuration.yaml" command: "smegd /shared/configuration.yaml"
sysctls:
- net.ipv6.conf.all.forwarding=1
- net.ipv6.conf.all.disable_ipv6=0
wg-3: wg-3:
image: wg-mesh-base:latest image: smegmesh-base:latest
cap_add: cap_add:
- NET_ADMIN - NET_ADMIN
- NET_RAW - NET_RAW
@ -39,4 +45,7 @@ services:
- net-1 - net-1
volumes: volumes:
- ./shared:/shared - ./shared:/shared
command: "wgmeshd /shared/configuration.yaml" command: "smegd /shared/configuration.yaml"
sysctls:
- net.ipv6.conf.all.forwarding=1
- net.ipv6.conf.all.disable_ipv6=0

View File

@ -1,14 +1,32 @@
certificatePath: "/wgmesh/cert/cert.pem" # Paths to the certificates modify
privateKeyPath: "/wgmesh/cert/priv.pem" # if not running from Smegmesh
caCertificatePath: "/wgmesh/cert/cacert.pem" certificatePath: "./cert/cert.pem"
privateKeyPath: "./cert/priv.pem"
caCertificatePath: "./cert/cacert.pem"
skipCertVerification: true skipCertVerification: true
# timeout is the configured grpc timeout
timeout: 5 timeout: 5
gRPCPort: "21906" # gRPC port to run the solution
advertiseRoutes: true gRPCPort: 4000
clusterSize: 32 # stubWg: whether to install WireGuard configurations
syncRate: 1 # if true just tests the control plane
interClusterChance: 0.15 stubWg: false
branchRate: 3 heartbeatInterval: 60
branch: 3
pullInterval: 20
infectionCount: 3 infectionCount: 3
keepAliveTime: 10 interClusterChance: 0.15
pruneTime: 20 syncInterval: 2
clusterSize: 64
logLevel: "info"
baseConfiguration:
# ipDiscovery: specifies how to find your IP address
ipDiscovery: "outgoing"
# alternative to ipDiscovery specify an actual endpoint yourself with publicEndpoint: "xxxx"
# role is the role that you are playing (peer | client)
# peers can only bootstrap meshes
role: "peer"
# advertise meshes to other meshes
advertiseRoute: true
# advertise default routes
advertiseDefaults: true

16
go.mod
View File

@ -1,14 +1,18 @@
module github.com/tim-beatham/wgmesh module github.com/tim-beatham/smegmesh
go 1.21.3 go 1.21.3
require ( require (
github.com/akamensky/argparse v1.4.0 github.com/akamensky/argparse v1.4.0
github.com/anandvarma/namegen v0.0.0-20230727084436-5197c6ea3255
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9 github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.16.0
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/jmespath/go-jmespath v0.4.0 github.com/jmespath/go-jmespath v0.4.0
github.com/jsimonetti/rtnetlink v1.3.5 github.com/jsimonetti/rtnetlink v1.3.5
github.com/lithammer/shortuuid v3.0.0+incompatible
github.com/miekg/dns v1.1.57
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.14.0 golang.org/x/sys v0.14.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
@ -24,7 +28,6 @@ require (
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.5.9 // indirect github.com/google/go-cmp v0.5.9 // indirect
@ -42,10 +45,13 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.13.0 // indirect golang.org/x/crypto v0.14.0 // indirect
golang.org/x/net v0.15.0 // indirect golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect
golang.org/x/sync v0.3.0 // indirect golang.org/x/mod v0.12.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sync v0.4.0 // indirect
golang.org/x/text v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect
golang.org/x/tools v0.13.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect
) )

141
go.sum Normal file
View File

@ -0,0 +1,141 @@
github.com/akamensky/argparse v1.4.0 h1:YGzvsTqCvbEZhL8zZu2AiA5nq805NZh75JNj4ajn1xc=
github.com/akamensky/argparse v1.4.0/go.mod h1:S5kwC7IuDcEr5VeXtGPRVZ5o/FdhcMlQz4IZQuw64xA=
github.com/anandvarma/namegen v0.0.0-20230727084436-5197c6ea3255 h1:aIAyyj4XPrke9Tc/umbBCzP5SKX/CHf3dKrL/PhH2lo=
github.com/anandvarma/namegen v0.0.0-20230727084436-5197c6ea3255/go.mod h1:MFyILur9tG8PxaCXGZVr/2BOnHtRIgxYejYFZdWLxr0=
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9 h1:+6JSfuxZgmURoIlGdnYnY/FLRGWGagLyiBjt/VLtwi4=
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9/go.mod h1:6UxoDE+thWsISXK93pxaOuOfkcAfCvDbg0eAnFmxL5E=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/cilium/ebpf v0.11.0 h1:V8gS/bTCCjX9uUnkUFUpPsksM8n1lXBAvHcpiFk1X2Y=
github.com/cilium/ebpf v0.11.0/go.mod h1:WE7CZAnqOL2RouJ4f1uyNhqr2P4CCvXFIqdRDUgWsVs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/jsimonetti/rtnetlink v1.3.5 h1:hVlNQNRlLDGZz31gBPicsG7Q53rnlsz1l1Ix/9XlpVA=
github.com/jsimonetti/rtnetlink v1.3.5/go.mod h1:0LFedyiTkebnd43tE4YAkWGIq9jQphow4CcwxaT2Y00=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/lithammer/shortuuid v3.0.0+incompatible h1:NcD0xWW/MZYXEHa6ITy6kaXN5nwm/V115vj2YXfhS0w=
github.com/lithammer/shortuuid v3.0.0+incompatible/go.mod h1:FR74pbAuElzOUuenUHTK2Tciko1/vKuIKS9dSkDrA4w=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI=
github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI=
github.com/miekg/dns v1.1.57 h1:Jzi7ApEIzwEPLHWRcafCN9LZSBbqQpxjt/wpgvg7wcM=
github.com/miekg/dns v1.1.57/go.mod h1:uqRjCRUuEAA6qsOiJvDd+CFo/vW+y5WR6SNmHE55hZk=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug=
golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 h1:EY138uSo1JYlDq+97u1FtcOUwPpIU6WL1Lkt7WpYjPA=
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM=
google.golang.org/grpc v1.58.1 h1:OL+Vz23DTtrrldqHK49FUOPHyY75rvFqJfXC84NYW58=
google.golang.org/grpc v1.58.1/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@ -4,28 +4,14 @@ import (
"fmt" "fmt"
"net/http" "net/http"
ipcRpc "net/rpc"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/what8words" "github.com/tim-beatham/smegmesh/pkg/what8words"
) )
const SockAddr = "/tmp/wgmesh_ipc.sock" // routesToApiRoute: convert the returned type to a JSON object
type ApiServer interface {
GetMeshes(c *gin.Context)
Run(addr string) error
}
type SmegServer struct {
router *gin.Engine
client *ipcRpc.Client
words *what8words.What8Words
}
func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route { func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
routes := make([]Route, len(meshNode.Routes)) routes := make([]Route, len(meshNode.Routes))
@ -44,6 +30,7 @@ func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
return routes return routes
} }
// meshNodeToAPImeshNode: convert daemon node to a JSON node
func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode { func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
if meshNode.Routes == nil { if meshNode.Routes == nil {
meshNode.Routes = make([]ctrlserver.MeshRoute, 0) meshNode.Routes = make([]ctrlserver.MeshRoute, 0)
@ -65,9 +52,16 @@ func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNo
PublicKey: meshNode.PublicKey, PublicKey: meshNode.PublicKey,
Alias: alias, Alias: alias,
Services: meshNode.Services, Services: meshNode.Services,
Stats: SmegStats{
TotalTransmit: meshNode.Stats.TransmitBytes,
TotalReceived: meshNode.Stats.ReceivedBytes,
KeepAliveInterval: meshNode.Stats.PersistentKeepAliveInterval,
AllowedIps: meshNode.Stats.AllowedIPs,
},
} }
} }
// meshToAPIMesh: Convert daemon mesh network to a JSON mesh network
func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh { func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh {
var smegMesh SmegMesh var smegMesh SmegMesh
smegMesh.MeshId = meshId smegMesh.MeshId = meshId
@ -80,6 +74,25 @@ func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) S
return smegMesh return smegMesh
} }
// putAlias: place an alias in the mesh
func (s *SmegServer) putAlias(meshId, alias string) error {
var reply string
return s.client.PutAlias(ipc.PutAliasArgs{
Alias: alias,
MeshId: meshId,
}, &reply)
}
func (s *SmegServer) putDescription(meshId, description string) error {
var reply string
return s.client.PutDescription(ipc.PutDescriptionArgs{
Description: description,
MeshId: meshId,
}, &reply)
}
// CreateMesh: creates a mesh network // CreateMesh: creates a mesh network
func (s *SmegServer) CreateMesh(c *gin.Context) { func (s *SmegServer) CreateMesh(c *gin.Context) {
var createMesh CreateMeshRequest var createMesh CreateMeshRequest
@ -92,13 +105,21 @@ func (s *SmegServer) CreateMesh(c *gin.Context) {
return return
} }
fmt.Printf("%+v\n", createMesh)
ipcRequest := ipc.NewMeshArgs{ ipcRequest := ipc.NewMeshArgs{
WgArgs: ipc.WireGuardArgs{
WgPort: createMesh.WgPort, WgPort: createMesh.WgPort,
Role: createMesh.Role,
Endpoint: createMesh.PublicEndpoint,
AdvertiseRoutes: createMesh.AdvertiseRoutes,
AdvertiseDefaultRoute: createMesh.AdvertiseDefaults,
},
} }
var reply string var reply string
err := s.client.Call("IpcHandler.CreateMesh", &ipcRequest, &reply) err := s.client.CreateMesh(&ipcRequest, &reply)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{ c.JSON(http.StatusBadRequest, &gin.H{
@ -107,6 +128,14 @@ func (s *SmegServer) CreateMesh(c *gin.Context) {
return return
} }
if createMesh.Alias != "" {
s.putAlias(reply, createMesh.Alias)
}
if createMesh.Description != "" {
s.putDescription(reply, createMesh.Description)
}
c.JSON(http.StatusOK, &gin.H{ c.JSON(http.StatusOK, &gin.H{
"meshid": reply, "meshid": reply,
}) })
@ -125,13 +154,19 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
ipcRequest := ipc.JoinMeshArgs{ ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId, MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap, IpAddress: joinMesh.Bootstrap,
Port: joinMesh.WgPort, WgArgs: ipc.WireGuardArgs{
WgPort: joinMesh.WgPort,
Endpoint: joinMesh.PublicEndpoint,
Role: joinMesh.Role,
AdvertiseRoutes: joinMesh.AdvertiseRoutes,
AdvertiseDefaultRoute: joinMesh.AdvertiseDefaults,
},
} }
var reply string var reply string
err := s.client.Call("IpcHandler.JoinMesh", &ipcRequest, &reply) err := s.client.JoinMesh(ipcRequest, &reply)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{ c.JSON(http.StatusBadRequest, &gin.H{
@ -140,6 +175,14 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
return return
} }
if joinMesh.Alias != "" {
s.putAlias(reply, joinMesh.Alias)
}
if joinMesh.Description != "" {
s.putDescription(reply, joinMesh.Description)
}
c.JSON(http.StatusOK, &gin.H{ c.JSON(http.StatusOK, &gin.H{
"status": "success", "status": "success",
}) })
@ -154,7 +197,7 @@ func (s *SmegServer) GetMesh(c *gin.Context) {
getMeshReply := new(ipc.GetMeshReply) getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &meshid, &getMeshReply) err := s.client.GetMesh(meshid, getMeshReply)
if err != nil { if err != nil {
c.JSON(http.StatusNotFound, c.JSON(http.StatusNotFound,
@ -169,10 +212,12 @@ func (s *SmegServer) GetMesh(c *gin.Context) {
c.JSON(http.StatusOK, mesh) c.JSON(http.StatusOK, mesh)
} }
// GetMeshes: return all the mesh networks that the
// user is a part of
func (s *SmegServer) GetMeshes(c *gin.Context) { func (s *SmegServer) GetMeshes(c *gin.Context) {
listMeshesReply := new(ipc.ListMeshReply) listMeshesReply := new(ipc.ListMeshReply)
err := s.client.Call("IpcHandler.ListMeshes", "", &listMeshesReply) err := s.client.ListMeshes(listMeshesReply)
if err != nil { if err != nil {
logging.Log.WriteErrorf(err.Error()) logging.Log.WriteErrorf(err.Error())
@ -185,7 +230,7 @@ func (s *SmegServer) GetMeshes(c *gin.Context) {
for _, mesh := range listMeshesReply.Meshes { for _, mesh := range listMeshesReply.Meshes {
getMeshReply := new(ipc.GetMeshReply) getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &mesh, &getMeshReply) err := s.client.GetMesh(mesh, getMeshReply)
if err != nil { if err != nil {
logging.Log.WriteErrorf(err.Error()) logging.Log.WriteErrorf(err.Error())
@ -199,13 +244,16 @@ func (s *SmegServer) GetMeshes(c *gin.Context) {
c.JSON(http.StatusOK, meshes) c.JSON(http.StatusOK, meshes)
} }
// Run: run the API server
func (s *SmegServer) Run(addr string) error { func (s *SmegServer) Run(addr string) error {
logging.Log.WriteInfof("Running API server") logging.Log.WriteInfof("Running API server")
return s.router.Run(addr) return s.router.Run(addr)
} }
// NewSmegServer: creates an instance of a new API server
// returns an error if something went wrong
func NewSmegServer(conf ApiServerConf) (ApiServer, error) { func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr) client, err := ipc.NewClientIpc()
if err != nil { if err != nil {
return nil, err return nil, err
@ -229,9 +277,19 @@ func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
words: words, words: words,
} }
router.GET("/meshes", smegServer.GetMeshes) v1 := router.Group("/api/v1")
router.GET("/mesh/:meshid", smegServer.GetMesh) {
router.POST("/mesh/create", smegServer.CreateMesh) meshes := v1.Group("/meshes")
router.POST("/mesh/join", smegServer.JoinMesh) {
meshes.GET("/", smegServer.GetMeshes)
}
mesh := v1.Group("/mesh")
{
mesh.GET("/:meshid", smegServer.GetMesh)
mesh.POST("/create", smegServer.CreateMesh)
mesh.POST("/join", smegServer.JoinMesh)
}
}
return smegServer, nil return smegServer, nil
} }

View File

@ -1,37 +1,129 @@
package api package api
import (
"time"
"github.com/gin-gonic/gin"
"github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/smegmesh/pkg/what8words"
)
// Route is an advertised route in the data store
type Route struct { type Route struct {
// Prefix is the advertised route prefix
Prefix string `json:"prefix"` Prefix string `json:"prefix"`
// Path is the hops the destination
Path []string `json:"path"` Path []string `json:"path"`
} }
type SmegNode struct { // SmegStats is the WireGuard stats that the underlying host
Alias string `json:"alias"` // has sent to the peer
WgHost string `json:"wgHost"` type SmegStats struct {
WgEndpoint string `json:"wgEndpoint"` // TotalTransmit number of bytes sent to the peer
Endpoint string `json:"endpoint"` TotalTransmit int64 `json:"totalTransmit"`
Timestamp int `json:"timestamp"` // TotalReceived number of bytes received from the peer
Description string `json:"description"` TotalReceived int64 `json:"totalReceived"`
PublicKey string `json:"publicKey"` // KeepAliveInterval WireGuard keepalive interval that is sent to the host
Routes []Route `json:"routes"` KeepAliveInterval time.Duration `json:"keepaliveInterval"`
Services map[string]string `json:"services"` // AllowsIps is the allowed path to the destination
AllowedIps []string `json:"allowedIps"`
} }
// SmegNode is a node in the mesh network
type SmegNode struct {
// Alias is the human readable name that the node is assocaited with
Alias string `json:"alias"`
// WgHost is the WireGuard IP address of the node. This is an IPv6
// address
WgHost string `json:"wgHost"`
// WgEndpoint is the physical endpoint of the host that packets
// are forwarded to
WgEndpoint string `json:"wgEndpoint"`
// Endpoint is the control plane endpoint of the host which
// grpc connections are to be sent along
Endpoint string `json:"endpoint"`
// Timestamp is the last time the signified it was alive.
// if the node is the leader this is evert heartBeatInterval
// otherwise this is the time the node joined the network
Timestamp int `json:"timestamp"`
// Description is the human readable description of the node
Description string `json:"description"`
// PublicKey is the WireGuard public key of the node
PublicKey string `json:"publicKey"`
// Routes is the routes that the node is advertising
Routes []Route `json:"routes"`
// Services is information about services that the node offers
Services map[string]string `json:"services"`
// Stats is the WireGuard stats of the node (if any)
Stats SmegStats `json:"stats"`
}
// SmegMesh encapsulates a single mesh in the API
type SmegMesh struct { type SmegMesh struct {
// MeshId is the mesh id of the network
MeshId string `json:"meshid"` MeshId string `json:"meshid"`
// Nodes is the nodes in the network keyed by their public
// key
Nodes map[string]SmegNode `json:"nodes"` Nodes map[string]SmegNode `json:"nodes"`
} }
// CreateMeshRequest encapsulates a request to create a mesh network
type CreateMeshRequest struct { type CreateMeshRequest struct {
// WgPort is the WireGuard to create the mesh in
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"` WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
// Role is the role to take on in the mesh
Role string `json:"role" binding:"required,eq=client|eq=peer"`
// AdvertiseRoutes: advertise thi mesh to other meshes
AdvertiseRoutes bool `json:"advertiseRoutes"`
// AdvertiseDefaults: advertise an exit point
AdvertiseDefaults bool `json:"advertiseDefaults"`
// Alias: alias of the node in the mesh
Alias string `json:"alias"`
// Description: description of the node in the mesh
Description string `json:"description"`
// PublicEndpoint: an alternative public endpoint to advertise
PublicEndpoint string `json:"publicEndpoint"`
} }
// JoinMeshRequests encapsulates a request to create a mesh network
type JoinMeshRequest struct { type JoinMeshRequest struct {
// WgPort is the WireGuard port to run the service on
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"` WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
// Bootstrap is a bootstrap node to use to join the network
Bootstrap string `json:"bootstrap" binding:"required"` Bootstrap string `json:"bootstrap" binding:"required"`
// MeshId is the ID of the mesh to join
MeshId string `json:"meshid" binding:"required"` MeshId string `json:"meshid" binding:"required"`
// Role is the role to take on in the mesh
Role string `json:"role" binding:"required,eq=client|eq=peer"`
// AdvertiseRoutes: advertise thi mesh to other meshes
AdvertiseRoutes bool `json:"advertiseRoutes"`
// AdvertiseDefaults: advertise an exit point
AdvertiseDefaults bool `json:"advertiseDefaults"`
// Alias: alias of the node in the mesh
Alias string `json:"alias"`
// Description: description of the node in the mesh
Description string `json:"description"`
// PublicEndpoint: an alternative public endpoint to advertise
PublicEndpoint string `json:"publicEndpoint"`
} }
// ApiServerConf configuration to instantiate the API server
type ApiServerConf struct { type ApiServerConf struct {
// WordsFile to use to map IP to words
WordsFile string WordsFile string
} }
// SmegSever is the GIN api server that runs the service
type SmegServer struct {
// gin router to use
router *gin.Engine
// client to invoke operations
client *ipc.SmegmeshIpc
// what8words to use to convert IP to an alias
words *what8words.What8Words
}
// ApiSever absrtacts the API server
type ApiServer interface {
Run(addr string) error
}

View File

@ -1,3 +1,5 @@
// automerge: package is depracated and unused. Please refer to crdt
// for crdt operations in the mesh
package automerge package automerge
import ( import (
@ -9,26 +11,36 @@ import (
"time" "time"
"github.com/automerge/automerge-go" "github.com/automerge/automerge-go"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// CrdtMeshManager manages nodes in the crdt mesh // CrdtMeshManager manage the CRDT datastore
type CrdtMeshManager struct { type CrdtMeshManager struct {
// MeshID of the mesh the datastore represents
MeshId string MeshId string
// IfName: corresponding ifName
IfName string IfName string
// Client: corresponding wireguard control client
Client *wgctrl.Client Client *wgctrl.Client
// doc: autommerge document
doc *automerge.Doc doc *automerge.Doc
// LastHash: last hash that the changes were made to
LastHash automerge.ChangeHash LastHash automerge.ChangeHash
// conf: WireGuard configuration
conf *conf.WgConfiguration conf *conf.WgConfiguration
// cache: stored cache of the list automerge document
// so that the store does not have to be repopulated each time
cache *MeshCrdt cache *MeshCrdt
// lastCachehash: hash of when the document was last changed
lastCacheHash automerge.ChangeHash lastCacheHash automerge.ChangeHash
} }
// AddNode as a node to the datastore
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
crdt, ok := node.(*MeshNodeCrdt) crdt, ok := node.(*MeshNodeCrdt)
@ -40,9 +52,14 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
crdt.Services = make(map[string]string) crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix() crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt) err := c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
if err != nil {
logging.Log.WriteInfof("error")
}
} }
// isPeer: returns true if the given node has type peer
func (c *CrdtMeshManager) isPeer(nodeId string) bool { func (c *CrdtMeshManager) isPeer(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId) node, err := c.doc.Path("nodes").Map().Get(nodeId)
@ -60,7 +77,8 @@ func (c *CrdtMeshManager) isPeer(nodeId string) bool {
} }
// isAlive: checks that the node's configuration has been updated // isAlive: checks that the node's configuration has been updated
// since the rquired keep alive time // since the rquired keep alive time. Depracated no longer works
// due to changes in approach
func (c *CrdtMeshManager) isAlive(nodeId string) bool { func (c *CrdtMeshManager) isAlive(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId) node, err := c.doc.Path("nodes").Map().Get(nodeId)
@ -74,10 +92,11 @@ func (c *CrdtMeshManager) isAlive(nodeId string) bool {
return false return false
} }
return true
// return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime) // return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
return true
} }
// GetPeers: get all the peers in the mesh
func (c *CrdtMeshManager) GetPeers() []string { func (c *CrdtMeshManager) GetPeers() []string {
keys, _ := c.doc.Path("nodes").Map().Keys() keys, _ := c.doc.Path("nodes").Map().Keys()
@ -88,7 +107,7 @@ func (c *CrdtMeshManager) GetPeers() []string {
return keys return keys
} }
// GetMesh(): Converts the document into a struct // GetMesh: Converts the document into a mesh network
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) { func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
changes, err := c.doc.Changes(c.lastCacheHash) changes, err := c.doc.Changes(c.lastCacheHash)
@ -110,7 +129,7 @@ func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return c.cache, nil return c.cache, nil
} }
// GetMeshId returns the meshid of the mesh // GetMeshId: returns the meshid of the mesh
func (c *CrdtMeshManager) GetMeshId() string { func (c *CrdtMeshManager) GetMeshId() string {
return c.MeshId return c.MeshId
} }
@ -131,6 +150,8 @@ func (c *CrdtMeshManager) Load(bytes []byte) error {
return nil return nil
} }
// NewCrdtNodeManagerParams: params to instantiate a new automerge
// datastore
type NewCrdtNodeMangerParams struct { type NewCrdtNodeMangerParams struct {
MeshId string MeshId string
DevName string DevName string
@ -139,7 +160,7 @@ type NewCrdtNodeMangerParams struct {
Client *wgctrl.Client Client *wgctrl.Client
} }
// NewCrdtNodeManager: Create a new crdt node manager // NewCrdtNodeManager: Create a new automerge crdt data store
func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, error) { func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, error) {
var manager CrdtMeshManager var manager CrdtMeshManager
manager.MeshId = params.MeshId manager.MeshId = params.MeshId
@ -151,17 +172,18 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
return &manager, nil return &manager, nil
} }
// NodeExists: returns true if the node exists. Returns false // NodeExists: returns true if the node exists other returns false
func (m *CrdtMeshManager) NodeExists(key string) bool { func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key) node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err == nil return node.Kind() == automerge.KindMap && err == nil
} }
// GetNode: gets a node from the mesh network.
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) { func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint) node, err := m.doc.Path("nodes").Map().Get(endpoint)
if node.Kind() != automerge.KindMap { if node.Kind() != automerge.KindMap {
return nil, fmt.Errorf("GetNode: something went wrong %s is not a map type") return nil, fmt.Errorf("getnode: node is not a map")
} }
if err != nil { if err != nil {
@ -177,10 +199,12 @@ func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
return meshNode, nil return meshNode, nil
} }
// Length: returns the number of nodes in the store
func (m *CrdtMeshManager) Length() int { func (m *CrdtMeshManager) Length() int {
return m.doc.Path("nodes").Map().Len() return m.doc.Path("nodes").Map().Len()
} }
// GetDevice: get the underlying WireGuard device
func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) { func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) {
dev, err := m.Client.Device(m.IfName) dev, err := m.Client.Device(m.IfName)
@ -191,7 +215,7 @@ func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) {
return dev, nil return dev, nil
} }
// HasChanges returns true if we have changes since the last time we synced // HasChanges: returns true if there are changes since last time synchronised
func (m *CrdtMeshManager) HasChanges() bool { func (m *CrdtMeshManager) HasChanges() bool {
changes, err := m.doc.Changes(m.LastHash) changes, err := m.doc.Changes(m.LastHash)
@ -205,6 +229,7 @@ func (m *CrdtMeshManager) HasChanges() bool {
return len(changes) > 0 return len(changes) > 0
} }
// SaveChanges: save changes to the datastore
func (m *CrdtMeshManager) SaveChanges() { func (m *CrdtMeshManager) SaveChanges() {
hashes := m.doc.Heads() hashes := m.doc.Heads()
hash := hashes[len(hashes)-1] hash := hashes[len(hashes)-1]
@ -213,6 +238,7 @@ func (m *CrdtMeshManager) SaveChanges() {
m.LastHash = hash m.LastHash = hash
} }
// UpdateTimeStamp: updates the timestamp of the document
func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error { func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -233,6 +259,7 @@ func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
return err return err
} }
// SetDescription: set the description of the given node
func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error { func (m *CrdtMeshManager) SetDescription(nodeId string, description string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -253,6 +280,7 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
return err return err
} }
// SetAlias: set the alias of the given node
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error { func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -273,6 +301,7 @@ func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
return err return err
} }
// AddService: add a service to the given node
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error { func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -294,6 +323,7 @@ func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
return err return err
} }
// RemoveService: remove a service from a node
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error { func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId) node, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -374,6 +404,7 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
return nil return nil
} }
// getRoutes: get the routes that the given node is directly advertising
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) { func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -400,6 +431,8 @@ func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
return lib.MapValues(routes), err return lib.MapValues(routes), err
} }
// GetRoutes: get all the routes that the node can see. The routes that the node
// can say may not be direct but cann also be indirect
func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) { func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode) node, err := m.GetNode(targetNode)
@ -443,12 +476,13 @@ func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, e
return routes, nil return routes, nil
} }
// RemoveNode: removes a node from the datastore
func (m *CrdtMeshManager) RemoveNode(nodeId string) error { func (m *CrdtMeshManager) RemoveNode(nodeId string) error {
err := m.doc.Path("nodes").Map().Delete(nodeId) err := m.doc.Path("nodes").Map().Delete(nodeId)
return err return err
} }
// DeleteRoutes deletes the specified routes // RemoveRoutes: withdraw all the routes the nodeID is advertising
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error { func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -482,30 +516,37 @@ func (m *CrdtMeshManager) GetConfiguration() *conf.WgConfiguration {
func (m *CrdtMeshManager) Mark(nodeId string) { func (m *CrdtMeshManager) Mark(nodeId string) {
} }
// GetSyncer: get the bi-directionally syncer to synchronise the document
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer { func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m) return NewAutomergeSync(m)
} }
// Prune: prune all dead nodes
func (m *CrdtMeshManager) Prune() error { func (m *CrdtMeshManager) Prune() error {
return nil return nil
} }
// Compare: compare two mesh node for equality
func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int { func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
return strings.Compare(m1.PublicKey, m2.PublicKey) return strings.Compare(m1.PublicKey, m2.PublicKey)
} }
// GetHostEndpoint: get the ctrl endpoint of the host
func (m *MeshNodeCrdt) GetHostEndpoint() string { func (m *MeshNodeCrdt) GetHostEndpoint() string {
return m.HostEndpoint return m.HostEndpoint
} }
// GetPublicKey: get the public key of the node
func (m *MeshNodeCrdt) GetPublicKey() (wgtypes.Key, error) { func (m *MeshNodeCrdt) GetPublicKey() (wgtypes.Key, error) {
return wgtypes.ParseKey(m.PublicKey) return wgtypes.ParseKey(m.PublicKey)
} }
// GetWgEndpoint: get the outer WireGuard endpoint
func (m *MeshNodeCrdt) GetWgEndpoint() string { func (m *MeshNodeCrdt) GetWgEndpoint() string {
return m.WgEndpoint return m.WgEndpoint
} }
// GetWgHost: get the WireGuard IP address of the host
func (m *MeshNodeCrdt) GetWgHost() *net.IPNet { func (m *MeshNodeCrdt) GetWgHost() *net.IPNet {
_, ipnet, err := net.ParseCIDR(m.WgHost) _, ipnet, err := net.ParseCIDR(m.WgHost)
@ -516,10 +557,12 @@ func (m *MeshNodeCrdt) GetWgHost() *net.IPNet {
return ipnet return ipnet
} }
// GetTimeStamp: get timestamp if when the node was last updated
func (m *MeshNodeCrdt) GetTimeStamp() int64 { func (m *MeshNodeCrdt) GetTimeStamp() int64 {
return m.Timestamp return m.Timestamp
} }
// GetRoutes: get all the routes advertised by the node
func (m *MeshNodeCrdt) GetRoutes() []mesh.Route { func (m *MeshNodeCrdt) GetRoutes() []mesh.Route {
return lib.Map(lib.MapValues(m.Routes), func(r Route) mesh.Route { return lib.Map(lib.MapValues(m.Routes), func(r Route) mesh.Route {
return &Route{ return &Route{
@ -529,10 +572,12 @@ func (m *MeshNodeCrdt) GetRoutes() []mesh.Route {
}) })
} }
// GetDescription: get the description of the node
func (m *MeshNodeCrdt) GetDescription() string { func (m *MeshNodeCrdt) GetDescription() string {
return m.Description return m.Description
} }
// GetIdentifier: get the iderntifier section of the ipv6 address
func (m *MeshNodeCrdt) GetIdentifier() string { func (m *MeshNodeCrdt) GetIdentifier() string {
ipv6 := m.WgHost[:len(m.WgHost)-4] ipv6 := m.WgHost[:len(m.WgHost)-4]
@ -541,10 +586,12 @@ func (m *MeshNodeCrdt) GetIdentifier() string {
return strings.Join(constituents, ":") return strings.Join(constituents, ":")
} }
// GetAlias: get the alias of the node
func (m *MeshNodeCrdt) GetAlias() string { func (m *MeshNodeCrdt) GetAlias() string {
return m.Alias return m.Alias
} }
// GetServices: get all the services the node is advertising
func (m *MeshNodeCrdt) GetServices() map[string]string { func (m *MeshNodeCrdt) GetServices() map[string]string {
services := make(map[string]string) services := make(map[string]string)
@ -561,6 +608,7 @@ func (n *MeshNodeCrdt) GetType() conf.NodeType {
return conf.NodeType(n.Type) return conf.NodeType(n.Type)
} }
// GetNodes: get all the nodes in the network
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode { func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode) nodes := make(map[string]mesh.MeshNode)
@ -582,15 +630,18 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
return nodes return nodes
} }
// GetDestination: get destination of the route
func (r *Route) GetDestination() *net.IPNet { func (r *Route) GetDestination() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(r.Destination) _, ipnet, _ := net.ParseCIDR(r.Destination)
return ipnet return ipnet
} }
// GetHopCount: get the number of hops to the destination
func (r *Route) GetHopCount() int { func (r *Route) GetHopCount() int {
return len(r.Path) return len(r.Path)
} }
// GetPath: get the total path which includes the number of hops
func (r *Route) GetPath() []string { func (r *Route) GetPath() []string {
return r.Path return r.Path
} }

View File

@ -1,15 +1,24 @@
// automerge: automerge is a CRDT library. Defines a CRDT
// datastore and methods to resolve conflicts
package automerge package automerge
import ( import (
"github.com/automerge/automerge-go" "github.com/automerge/automerge-go"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
) )
// AutomergeSync: defines a synchroniser to bi-directionally synchronise the
// two states
type AutomergeSync struct { type AutomergeSync struct {
// state: the automerge sync state to use
state *automerge.SyncState state *automerge.SyncState
// manager: the corresponding data store that we are merging
manager *CrdtMeshManager manager *CrdtMeshManager
} }
// GenerateMessage: geenrate a new automerge message to synchronise
// returns a byte of the message and a boolean of whether or not there
// are more messages in the sequence
func (a *AutomergeSync) GenerateMessage() ([]byte, bool) { func (a *AutomergeSync) GenerateMessage() ([]byte, bool) {
msg, valid := a.state.GenerateMessage() msg, valid := a.state.GenerateMessage()
@ -20,6 +29,8 @@ func (a *AutomergeSync) GenerateMessage() ([]byte, bool) {
return msg.Bytes(), true return msg.Bytes(), true
} }
// RecvMessage: receive an automerge message to merge in the datastore
// returns an error if unsuccessful
func (a *AutomergeSync) RecvMessage(msg []byte) error { func (a *AutomergeSync) RecvMessage(msg []byte) error {
_, err := a.state.ReceiveMessage(msg) _, err := a.state.ReceiveMessage(msg)
@ -30,11 +41,13 @@ func (a *AutomergeSync) RecvMessage(msg []byte) error {
return nil return nil
} }
// Complete: complete the synchronisation process
func (a *AutomergeSync) Complete() { func (a *AutomergeSync) Complete() {
logging.Log.WriteInfof("Sync Completed") logging.Log.WriteInfof("sync completed")
a.manager.SaveChanges() a.manager.SaveChanges()
} }
// NewAutomergeSync: instantiates a new automerge syncer
func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync { func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync {
return &AutomergeSync{ return &AutomergeSync{
state: automerge.NewSyncState(manager.doc), state: automerge.NewSyncState(manager.doc),

View File

@ -1,14 +1,14 @@
package automerge package automerge
import ( import (
"slices" "net"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -22,7 +22,7 @@ func setUpTests() *TestParams {
DevName: "wg0", DevName: "wg0",
Port: 5000, Port: 5000,
Client: nil, Client: nil,
Conf: conf.DaemonConfiguration{}, Conf: &conf.WgConfiguration{},
}) })
return &TestParams{ return &TestParams{
@ -31,22 +31,26 @@ func setUpTests() *TestParams {
} }
func getTestNode() mesh.MeshNode { func getTestNode() mesh.MeshNode {
pubKey, _ := wgtypes.GeneratePrivateKey()
return &MeshNodeCrdt{ return &MeshNodeCrdt{
HostEndpoint: "public-endpoint:8080", HostEndpoint: "public-endpoint:8080",
WgEndpoint: "public-endpoint:21906", WgEndpoint: "public-endpoint:21906",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128", WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128",
PublicKey: "AAAAAAAAAAAA", PublicKey: pubKey.String(),
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
Description: "A node that we are adding", Description: "A node that we are adding",
} }
} }
func getTestNode2() mesh.MeshNode { func getTestNode2() mesh.MeshNode {
pubKey, _ := wgtypes.GeneratePrivateKey()
return &MeshNodeCrdt{ return &MeshNodeCrdt{
HostEndpoint: "public-endpoint:8081", HostEndpoint: "public-endpoint:8081",
WgEndpoint: "public-endpoint:21907", WgEndpoint: "public-endpoint:21907",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d219/128", WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d219/128",
PublicKey: "BBBBBBBBB", PublicKey: pubKey.String(),
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
Description: "A node that we are adding", Description: "A node that we are adding",
} }
@ -54,9 +58,11 @@ func getTestNode2() mesh.MeshNode {
func TestAddNodeNodeExists(t *testing.T) { func TestAddNodeNodeExists(t *testing.T) {
testParams := setUpTests() testParams := setUpTests()
testParams.manager.AddNode(getTestNode()) node := getTestNode()
testParams.manager.AddNode(node)
node, err := testParams.manager.GetNode("public-endpoint:8080") pubKey, _ := node.GetPublicKey()
node, err := testParams.manager.GetNode(pubKey.String())
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -70,25 +76,27 @@ func TestAddNodeNodeExists(t *testing.T) {
func TestAddNodeAddRoute(t *testing.T) { func TestAddNodeAddRoute(t *testing.T) {
testParams := setUpTests() testParams := setUpTests()
testNode := getTestNode() testNode := getTestNode()
testParams.manager.AddNode(testNode) pubKey, _ := testNode.GetPublicKey()
testParams.manager.AddRoutes(testNode.GetHostEndpoint(), "fd:1c64:1d00::/48")
updatedNode, err := testParams.manager.GetNode(testNode.GetHostEndpoint()) _, destination, _ := net.ParseCIDR("fd:1c64:1d00::/48")
testParams.manager.AddNode(testNode)
testParams.manager.AddRoutes(pubKey.String(), &mesh.RouteStub{
Destination: destination,
Path: make([]string, 0),
})
updatedNode, err := testParams.manager.GetNode(pubKey.String())
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if updatedNode == nil { if updatedNode == nil {
t.Fatalf(`Node does not exist in the mesh`) t.Fatalf(`node does not exist in the mesh`)
} }
routes := updatedNode.GetRoutes() routes := updatedNode.GetRoutes()
if !slices.Contains(routes, "fd:1c64:1d00::/48") {
t.Fatal("Route node not added")
}
if len(routes) != 1 { if len(routes) != 1 {
t.Fatal(`Route length mismatch`) t.Fatal(`Route length mismatch`)
} }
@ -253,7 +261,9 @@ func TestUpdateTimeStampNodeExists(t *testing.T) {
node := getTestNode() node := getTestNode()
testParams.manager.AddNode(node) testParams.manager.AddNode(node)
err := testParams.manager.UpdateTimeStamp(node.GetHostEndpoint()) pubKey, _ := node.GetPublicKey()
err := testParams.manager.UpdateTimeStamp(pubKey.String())
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -282,7 +292,12 @@ func TestSetDescriptionNodeExists(t *testing.T) {
func TestAddRoutesNodeDoesNotExist(t *testing.T) { func TestAddRoutesNodeDoesNotExist(t *testing.T) {
testParams := setUpTests() testParams := setUpTests()
err := testParams.manager.AddRoutes("AAAAA", "fd:1c64:1d00::/48") _, destination, _ := net.ParseCIDR("fd:1c64:1d00::/48")
err := testParams.manager.AddRoutes("AAAAA", &mesh.RouteStub{
Destination: destination,
Path: make([]string, 0),
})
if err == nil { if err == nil {
t.Error(err) t.Error(err)
@ -293,16 +308,11 @@ func TestCompareComparesByPublicKey(t *testing.T) {
node := getTestNode().(*MeshNodeCrdt) node := getTestNode().(*MeshNodeCrdt)
node2 := getTestNode2().(*MeshNodeCrdt) node2 := getTestNode2().(*MeshNodeCrdt)
if node.Compare(node2) != -1 { pubKey1, _ := node.GetPublicKey()
t.Fatalf(`node is alphabetically before node2`) pubKey2, _ := node2.GetPublicKey()
}
if node2.Compare(node) != 1 { if node.Compare(node2) != strings.Compare(pubKey1.String(), pubKey2.String()) {
t.Fatalf(`node is alphabetical;y before node2`) t.Fatalf(`compare failed`)
}
if node.Compare(node) != 0 {
t.Fatalf(`node is equal to node`)
} }
} }

View File

@ -3,13 +3,16 @@ package automerge
import ( import (
"fmt" "fmt"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
// CrdtProviderFactory: abstracts the instantiation of an automerge
// datastore
type CrdtProviderFactory struct{} type CrdtProviderFactory struct{}
// CreateMesh: create a new mesh datastore
func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) { func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
return NewCrdtNodeManager(&NewCrdtNodeMangerParams{ return NewCrdtNodeManager(&NewCrdtNodeMangerParams{
MeshId: params.MeshId, MeshId: params.MeshId,
@ -19,16 +22,17 @@ func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams)
}) })
} }
// MeshNodeFactory: abstracts the instnatiation of a node
type MeshNodeFactory struct { type MeshNodeFactory struct {
Config conf.DaemonConfiguration Config conf.DaemonConfiguration
} }
// Build builds the mesh node that represents the host machine to add // Build: builds the mesh node that represents the host machine to add
// to the mesh // to the mesh
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode { func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params) hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort) grpcEndpoint := fmt.Sprintf("%s:%d", hostName, f.Config.GrpcPort)
if *params.MeshConfig.Role == conf.CLIENT_ROLE { if *params.MeshConfig.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-" grpcEndpoint = "-"
@ -48,7 +52,7 @@ func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNod
} }
} }
// getAddress returns the routable address of the machine. // getAddress: returns the routable address of the machine.
func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string { func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string {
var hostName string = "" var hostName string = ""
@ -59,7 +63,7 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
} else { } else {
ipFunc := lib.GetPublicIP ipFunc := lib.GetPublicIP
if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY { if *params.MeshConfig.IPDiscovery == conf.OUTGOING_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP ipFunc = lib.GetOutboundIP
} }

View File

@ -6,10 +6,12 @@ import (
"strings" "strings"
) )
// CmdRunner: run cmd commands when instantiating a network
type CmdRunner interface { type CmdRunner interface {
RunCommands(commands ...string) error RunCommands(commands ...string) error
} }
// UnixCmdRunner: Run UNIX commands
type UnixCmdRunner struct{} type UnixCmdRunner struct{}
// RunCommand: runs the unix command. It splits the command into fields // RunCommand: runs the unix command. It splits the command into fields
@ -20,6 +22,7 @@ func RunCommand(cmd string) error {
return c.Run() return c.Run()
} }
// RunCommands: run a series of commands
func (l *UnixCmdRunner) RunCommands(commands ...string) error { func (l *UnixCmdRunner) RunCommands(commands ...string) error {
for _, cmd := range commands { for _, cmd := range commands {
err := RunCommand(cmd) err := RunCommand(cmd)

View File

@ -8,14 +8,7 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
type WgMeshConfigurationError struct { // NodeType types of the node either peer or client
msg string
}
func (m *WgMeshConfigurationError) Error() string {
return m.msg
}
type NodeType string type NodeType string
const ( const (
@ -23,11 +16,23 @@ const (
CLIENT_ROLE NodeType = "client" CLIENT_ROLE NodeType = "client"
) )
// IPDiscovery: what IPDiscovery service to use
type IPDiscovery string type IPDiscovery string
const ( const (
PUBLIC_IP_DISCOVERY = "public" // Public IP use an IP service to discover your IP
DNS_IP_DISCOVERY = "dns" PUBLIC_IP_DISCOVERY IPDiscovery = "public"
// Outgonig: Use your labelled packet IP
OUTGOING_IP_DISCOVERY IPDiscovery = "outgoing"
)
// Loglevel: what log level to use either error info or warning
type LogLevel string
const (
ERROR LogLevel = "error"
WARNING LogLevel = "warning"
INFO LogLevel = "info"
) )
// WgConfiguration contains per-mesh WireGuard configuration. Contains poitner types only so we can // WgConfiguration contains per-mesh WireGuard configuration. Contains poitner types only so we can
@ -35,19 +40,18 @@ const (
type WgConfiguration struct { type WgConfiguration struct {
// IPDIscovery: how to discover your IP if not specified. Use your outgoing IP or use a public // IPDIscovery: how to discover your IP if not specified. Use your outgoing IP or use a public
// service for IPDiscoverability // service for IPDiscoverability
IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=dns"` IPDiscovery *IPDiscovery `yaml:"ipDiscovery" validate:"required,eq=public|eq=outgoing"`
// AdvertiseRoutes: specifies whether the node can act as a router routing packets between meshes // AdvertiseRoutes: specifies whether the node can act as a router routing packets between meshes
AdvertiseRoutes *bool `yaml:"advertiseRoutes"` AdvertiseRoutes *bool `yaml:"advertiseRoute" validate:"required"`
// AdvertiseDefaultRoute: specifies whether or not this route should advertise a default route // AdvertiseDefaultRoute: specifies whether or not this route should advertise a default route
// for all nodes to route their packets to // for all nodes to route their packets to
AdvertiseDefaultRoute *bool `yaml:"advertiseDefaults"` AdvertiseDefaultRoute *bool `yaml:"advertiseDefaults" validate:"required"`
// Endpoint contains what value should be set as the public endpoint of this node // Endpoint contains what value should be set as the public endpoint of this node
Endpoint *string `yaml:"publicEndpoint"` Endpoint *string `yaml:"publicEndpoint"`
// Role specifies whether or not the user is globally accessible. // Role specifies whether or not the user is globally accessible.
// If the user is globaly accessible they specify themselves as a client. // If the user is globaly accessible they specify themselves as a client.
Role *NodeType `yaml:"role" validate:"required,eq=client|eq=peer"` Role *NodeType `yaml:"role" validate:"required,eq=client|eq=peer"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers. // KeepAliveWg configures the implementation so that we send keep alive packets to peers.
// KeepAlive can only be set if role is type client
KeepAliveWg *int `yaml:"keepAliveWg" validate:"omitempty,gte=0"` KeepAliveWg *int `yaml:"keepAliveWg" validate:"omitempty,gte=0"`
// PreUp are WireGuard commands to run before adding the WG interface // PreUp are WireGuard commands to run before adding the WG interface
PreUp []string `yaml:"preUp"` PreUp []string `yaml:"preUp"`
@ -61,11 +65,11 @@ type WgConfiguration struct {
type DaemonConfiguration struct { type DaemonConfiguration struct {
// CertificatePath is the path to the certificate to use in mTLS // CertificatePath is the path to the certificate to use in mTLS
CertificatePath string `yaml:"certificatePath" validate:"required,file"` CertificatePath string `yaml:"certificatePath" validate:"required"`
// PrivateKeypath is the path to the clients private key in mTLS // PrivateKeypath is the path to the clients private key in mTLS
PrivateKeyPath string `yaml:"privateKeyPath" validate:"required,file"` PrivateKeyPath string `yaml:"privateKeyPath" validate:"required"`
// CaCeritifcatePath path to the certificate of the trust certificate authority // CaCeritifcatePath path to the certificate of the trust certificate authority
CaCertificatePath string `yaml:"caCertificatePath" validate:"required,file"` CaCertificatePath string `yaml:"caCertificatePath" validate:"required"`
// SkipCertVerification specify to skip certificate verification. Should only be used // SkipCertVerification specify to skip certificate verification. Should only be used
// in test environments // in test environments
SkipCertVerification bool `yaml:"skipCertVerification"` SkipCertVerification bool `yaml:"skipCertVerification"`
@ -73,26 +77,28 @@ type DaemonConfiguration struct {
GrpcPort int `yaml:"gRPCPort" validate:"required"` GrpcPort int `yaml:"gRPCPort" validate:"required"`
// Timeout number of seconds without response that a node is considered unreachable by gRPC // Timeout number of seconds without response that a node is considered unreachable by gRPC
Timeout int `yaml:"timeout" validate:"required,gte=1"` Timeout int `yaml:"timeout" validate:"required,gte=1"`
// Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types // StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"` StubWg bool `yaml:"stubWg"`
// SyncRate specifies how long the minimum time should be between synchronisation // SyncInterval specifies how long the minimum time should be between synchronisation
SyncRate int `yaml:"syncRate" validate:"required,gte=1"` SyncInterval int `yaml:"syncInterval" validate:"required,gte=1"`
// KeepAliveTime: number of seconds before the leader of the mesh sends an update to // PullInterval specifies the interval between checking for configuration changes
PullInterval int `yaml:"pullInterval" validate:"gte=0"`
// Heartbeat: number of seconds before the leader of the mesh sends an update to
// send to every member in the mesh // send to every member in the mesh
KeepAliveTime int `yaml:"keepAliveTime" validate:"required,gte=1"` Heartbeat int `yaml:"heartbeatInterval" validate:"required,gte=1"`
// ClusterSize specifies how many neighbours you should synchronise with per round // ClusterSize specifies how many neighbours you should synchronise with per round
ClusterSize int `yaml:"clusterSize" valdiate:"required,gt=0"` ClusterSize int `yaml:"clusterSize" validate:"gte=1"`
// InterClusterChance specifies the probabilityof inter-cluster communication in a sync round // InterClusterChance specifies the probabilityof inter-cluster communication in a sync round
InterClusterChance float64 `yaml:"interClusterChance" valdiate:"required,gt=0"` InterClusterChance float64 `yaml:"interClusterChance" validate:"gt=0"`
// BranchRate specifies the number of nodes to synchronise with when a node has // Branch specifies the number of nodes to synchronise with when a node has
// new changes to send to the mesh // new changes to send to the mesh
BranchRate int `yaml:"branchRate" validate:"required,gte=1"` Branch int `yaml:"branch" validate:"required,gte=1"`
// InfectionCount: number of time to sync before an update can no longer be 'caught' // InfectionCount: number of time to sync before an update can no longer be 'caught'
InfectionCount int `yaml:"infectionCount" validate:"required,gte=1"` InfectionCount int `yaml:"infectionCount" validate:"required,gte=1"`
// BaseConfiguration base WireGuard configuration to use, this is used when none is provided // BaseConfiguration base WireGuard configuration to use, this is used when none is provided
BaseConfiguration WgConfiguration `yaml:"baseConfiguration" validate:"required"` BaseConfiguration WgConfiguration `yaml:"baseConfiguration" validate:"required"`
// LogLevel specifies the log level to output, defaults is warning
LogLevel LogLevel `yaml:"logLevel" validate:"eq=info|eq=warning|eq=error"`
} }
// ValdiateMeshConfiguration: validates the mesh configuration // ValdiateMeshConfiguration: validates the mesh configuration
@ -120,32 +126,21 @@ func ValidateMeshConfiguration(conf *WgConfiguration) error {
} }
// ValidateDaemonConfiguration: validates the dameon configuration that is used. // ValidateDaemonConfiguration: validates the dameon configuration that is used.
func ValidateDaemonConfiguration(c *DaemonConfiguration) error { func ValidateDaemonConfiguration(conf *DaemonConfiguration) error {
if conf.BaseConfiguration.KeepAliveWg == nil {
var keepAlive int = 0
conf.BaseConfiguration.KeepAliveWg = &keepAlive
}
if conf.LogLevel == "" {
conf.LogLevel = WARNING
}
validate := validator.New(validator.WithRequiredStructEnabled()) validate := validator.New(validator.WithRequiredStructEnabled())
err := validate.Struct(c) err := validate.Struct(conf)
return err return err
} }
// ParseMeshConfiguration: parses the mesh network configuration. Parses parameters such as
// keepalive time, role and so forth.
func ParseMeshConfiguration(filePath string) (*WgConfiguration, error) {
var conf WgConfiguration
yamlBytes, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
err = yaml.Unmarshal(yamlBytes, &conf)
if err != nil {
return nil, err
}
return &conf, ValidateMeshConfiguration(&conf)
}
// ParseDaemonConfiguration parses the mesh configuration and validates the configuration // ParseDaemonConfiguration parses the mesh configuration and validates the configuration
func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) { func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) {
var conf DaemonConfiguration var conf DaemonConfiguration

View File

@ -1,13 +1,40 @@
package conf package conf
import "testing" import (
"testing"
)
func getExampleConfiguration() *DaemonConfiguration { func getExampleConfiguration() *DaemonConfiguration {
discovery := PUBLIC_IP_DISCOVERY
advertiseRoutes := false
advertiseDefaultRoute := false
endpoint := "abc.com:123"
nodeType := CLIENT_ROLE
keepAliveWg := 0
return &DaemonConfiguration{ return &DaemonConfiguration{
CertificatePath: "./cert/cert.pem", CertificatePath: "../../../cert/cert.pem",
PrivateKeyPath: "./cert/key.pem", PrivateKeyPath: "../../../cert/priv.pem",
CaCertificatePath: "./cert/ca.pems", CaCertificatePath: "../../../cert/cacert.pem",
SkipCertVerification: true, SkipCertVerification: true,
GrpcPort: 25,
Timeout: 5,
StubWg: false,
SyncInterval: 2,
Heartbeat: 2,
ClusterSize: 64,
InterClusterChance: 0.15,
Branch: 3,
PullInterval: 0,
InfectionCount: 2,
BaseConfiguration: WgConfiguration{
IPDiscovery: &discovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Endpoint: &endpoint,
Role: &nodeType,
KeepAliveWg: &keepAliveWg,
},
} }
} }
@ -55,9 +82,152 @@ func TestConfigurationGrpcPortEmpty(t *testing.T) {
} }
} }
func TestIPDiscoveryNotSet(t *testing.T) {
conf := getExampleConfiguration()
ipDiscovery := IPDiscovery("djdsjdskd")
conf.BaseConfiguration.IPDiscovery = &ipDiscovery
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestAdvertiseRoutesNotSet(t *testing.T) {
conf := getExampleConfiguration()
conf.BaseConfiguration.AdvertiseRoutes = nil
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestAdvertiseDefaultRouteNotSet(t *testing.T) {
conf := getExampleConfiguration()
conf.BaseConfiguration.AdvertiseDefaultRoute = nil
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestKeepAliveWgNegative(t *testing.T) {
conf := getExampleConfiguration()
keepAliveWg := -1
conf.BaseConfiguration.KeepAliveWg = &keepAliveWg
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestRoleTypeNotValid(t *testing.T) {
conf := getExampleConfiguration()
role := NodeType("bruhhh")
conf.BaseConfiguration.Role = &role
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestRoleTypeNotSpecified(t *testing.T) {
conf := getExampleConfiguration()
conf.BaseConfiguration.Role = nil
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`invalid role type`)
}
}
func TestBranchRateZero(t *testing.T) {
conf := getExampleConfiguration()
conf.Branch = 0
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestsyncTimeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.SyncInterval = 0
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestKeepAliveTimeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.Heartbeat = 0
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestClusterSizeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.ClusterSize = 0
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestInterClusterChanceZero(t *testing.T) {
conf := getExampleConfiguration()
conf.InterClusterChance = 0
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestInfectionCountOne(t *testing.T) {
conf := getExampleConfiguration()
conf.InfectionCount = 0
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestPullTimeNegative(t *testing.T) {
conf := getExampleConfiguration()
conf.PullInterval = -1
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestValidConfiguration(t *testing.T) { func TestValidConfiguration(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
err := ValidateDaemonConfiguration(conf) err := ValidateDaemonConfiguration(conf)
if err != nil { if err != nil {

View File

@ -7,25 +7,30 @@ import (
"slices" "slices"
) )
// ConnCluster splits nodes into clusters where nodes in a cluster communicate // ConnCluster: splits nodes into clusters where nodes in a cluster communicate
// frequently and nodes outside of a cluster communicate infrequently // frequently and nodes outside of a cluster communicate infrequently
type ConnCluster interface { type ConnCluster interface {
// Getneighbours: get neighbours of the cluster the node is in
GetNeighbours(global []string, selfId string) []string GetNeighbours(global []string, selfId string) []string
// GetInterCluster: get the cluster to communicate with
GetInterCluster(global []string, selfId string) string GetInterCluster(global []string, selfId string) string
} }
// ConnnClusterImpl: implementation of the connection cluster
type ConnClusterImpl struct { type ConnClusterImpl struct {
clusterSize int clusterSize int
} }
// perform binary search to attain a size of a group
func binarySearch(global []string, selfId string, groupSize int) (int, int) { func binarySearch(global []string, selfId string, groupSize int) (int, int) {
slices.Sort(global) slices.Sort(global)
lower := 0 lower := 0
higher := len(global) - 1 higher := len(global) - 1
mid := (lower + higher) / 2
for (higher+1)-lower > groupSize { for (higher+1)-lower > groupSize {
mid := (lower + higher) / 2
if global[mid] < selfId { if global[mid] < selfId {
lower = mid + 1 lower = mid + 1
} else if global[mid] > selfId { } else if global[mid] > selfId {
@ -33,14 +38,12 @@ func binarySearch(global []string, selfId string, groupSize int) (int, int) {
} else { } else {
break break
} }
mid = (lower + higher) / 2
} }
return lower, int(math.Min(float64(lower+groupSize), float64(len(global)))) return lower, int(math.Min(float64(lower+groupSize), float64(len(global))))
} }
// GetNeighbours return the neighbours 'nearest' to you. In this implementation the // GetNeighbours: return the neighbours 'nearest' to you. In this implementation the
// neighbours aren't actually the ones nearest to you but just the ones nearest // neighbours aren't actually the ones nearest to you but just the ones nearest
// to you alphabetically. Perform binary search to get the total group // to you alphabetically. Perform binary search to get the total group
func (i *ConnClusterImpl) GetNeighbours(global []string, selfId string) []string { func (i *ConnClusterImpl) GetNeighbours(global []string, selfId string) []string {
@ -51,19 +54,22 @@ func (i *ConnClusterImpl) GetNeighbours(global []string, selfId string) []string
return global[lower:higher] return global[lower:higher]
} }
// GetInterCluster get nodes not in your cluster. Every round there is a given chance // GetInterCluster: get nodes not in your cluster. Every round there is a given chance
// you will communicate with a random node that is not in your cluster. // you will communicate with a random node that is not in your cluster.
func (i *ConnClusterImpl) GetInterCluster(global []string, selfId string) string { func (i *ConnClusterImpl) GetInterCluster(global []string, selfId string) string {
// Doesn't matter if not in it. Get index of where the node 'should' be // Doesn't matter if not in it. Get index of where the node 'should' be
slices.Sort(global)
index, _ := binarySearch(global, selfId, 1) index, _ := binarySearch(global, selfId, 1)
numClusters := math.Ceil(float64(len(global)) / float64(i.clusterSize))
randomCluster := rand.Intn(int(numClusters)-1) + 1 randomCluster := rand.Intn(2) + 1
neighbourIndex := (index + randomCluster) % len(global) // cluster is considered a heap
neighbourIndex := (2*index + (randomCluster * i.clusterSize)) % len(global)
return global[neighbourIndex] return global[neighbourIndex]
} }
// NewConnCluster: instantiate a new connection cluster of a given group size.
func NewConnCluster(clusterSize int) (ConnCluster, error) { func NewConnCluster(clusterSize int) (ConnCluster, error) {
log2Cluster := math.Log2(float64(clusterSize)) log2Cluster := math.Log2(float64(clusterSize))

View File

@ -6,7 +6,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
) )
@ -18,6 +18,7 @@ type PeerConnection interface {
GetClient() (*grpc.ClientConn, error) GetClient() (*grpc.ClientConn, error)
} }
// PeerConenctionFactory: create a new connection to a peer
type PeerConnectionFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error) type PeerConnectionFactory = func(clientConfig *tls.Config, server string) (PeerConnection, error)
// WgCtrlConnection implements PeerConnection. // WgCtrlConnection implements PeerConnection.

View File

@ -7,7 +7,7 @@ import (
"os" "os"
"sync" "sync"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
) )
// ConnectionManager defines an interface for maintaining peer connections // ConnectionManager defines an interface for maintaining peer connections
@ -19,9 +19,11 @@ type ConnectionManager interface {
// If the endpoint does not exist then add the connection. Returns an error // If the endpoint does not exist then add the connection. Returns an error
// if something went wrong // if something went wrong
GetConnection(endPoint string) (PeerConnection, error) GetConnection(endPoint string) (PeerConnection, error)
// HasConnections returns true if a client has already registered at the givne // HasConnections returns true if a peer has already registered at the given
// endpoint or false otherwise. // endpoint or false otherwise.
HasConnection(endPoint string) bool HasConnection(endPoint string) bool
// Removes a connection if it exists
RemoveConnection(endPoint string) error
// Goes through all the connections and closes eachone // Goes through all the connections and closes eachone
Close() error Close() error
} }
@ -32,7 +34,6 @@ type ConnectionManagerImpl struct {
// clientConnections maps an endpoint to a connection // clientConnections maps an endpoint to a connection
conLoc sync.RWMutex conLoc sync.RWMutex
clientConnections map[string]PeerConnection clientConnections map[string]PeerConnection
serverConfig *tls.Config
clientConfig *tls.Config clientConfig *tls.Config
connFactory PeerConnectionFactory connFactory PeerConnectionFactory
} }
@ -61,16 +62,8 @@ func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager
return nil, err return nil, err
} }
serverAuth := tls.RequireAndVerifyClientCert
if params.SkipCertVerification {
serverAuth = tls.RequireAnyClientCert
}
certPool := x509.NewCertPool() certPool := x509.NewCertPool()
if !params.SkipCertVerification {
if params.CaCert == "" { if params.CaCert == "" {
return nil, errors.New("CA Cert is not specified") return nil, errors.New("CA Cert is not specified")
} }
@ -81,17 +74,13 @@ func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager
return nil, err return nil, err
} }
certPool.AppendCertsFromPEM(caCert) if ok := certPool.AppendCertsFromPEM(caCert); !ok {
} return nil, errors.New("could not parse PEM")
serverConfig := &tls.Config{
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert},
} }
clientConfig := &tls.Config{ clientConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: params.SkipCertVerification, InsecureSkipVerify: params.SkipCertVerification,
Certificates: []tls.Certificate{cert},
RootCAs: certPool, RootCAs: certPool,
} }
@ -99,7 +88,6 @@ func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager
connMgr := ConnectionManagerImpl{ connMgr := ConnectionManagerImpl{
sync.RWMutex{}, sync.RWMutex{},
connections, connections,
serverConfig,
clientConfig, clientConfig,
params.ConnFactory, params.ConnFactory,
} }
@ -150,6 +138,21 @@ func (m *ConnectionManagerImpl) HasConnection(endPoint string) bool {
return exists return exists
} }
// RemoveConnection removes the given connection if it exists
func (m *ConnectionManagerImpl) RemoveConnection(endPoint string) error {
m.conLoc.Lock()
connection, ok := m.clientConnections[endPoint]
var err error
if ok {
err = connection.Close()
delete(m.clientConnections, endPoint)
}
m.conLoc.Unlock()
return err
}
func (m *ConnectionManagerImpl) Close() error { func (m *ConnectionManagerImpl) Close() error {
for _, conn := range m.clientConnections { for _, conn := range m.clientConnections {
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {

View File

@ -53,13 +53,13 @@ func TestNewConnectionManagerCACertDoesNotExistAndVerify(t *testing.T) {
func TestNewConnectionManagerCACertDoesNotExistAndNotVerify(t *testing.T) { func TestNewConnectionManagerCACertDoesNotExistAndNotVerify(t *testing.T) {
params := getConnectionManagerParams() params := getConnectionManagerParams()
params.CaCert = "" params.CaCert = "./cert/sdjsdjsdjk.pem"
params.SkipCertVerification = true params.SkipCertVerification = true
_, err := NewConnectionManager(params) _, err := NewConnectionManager(params)
if err != nil { if err == nil {
t.Fatal(`an error should not be thrown`) t.Fatalf(`an error should be thrown`)
} }
} }

View File

@ -2,22 +2,23 @@ package conn
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509"
"errors"
"fmt" "fmt"
"net" "net"
"os"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
) )
// ConnectionServer manages gRPC server peer connections // ConnectionServer manages gRPC server peer connections
type ConnectionServer struct { type ConnectionServer struct {
// tlsConfiguration of the server
serverConfig *tls.Config
// server an instance of the grpc server // server an instance of the grpc server
server *grpc.Server // the authentication service to authenticate nodes server *grpc.Server
// the ctrl service to manage node // the ctrl service to manage node
ctrlProvider rpc.MeshCtrlServerServer ctrlProvider rpc.MeshCtrlServerServer
// the sync service to synchronise nodes // the sync service to synchronise nodes
@ -48,9 +49,26 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
serverAuth = tls.RequireAnyClientCert serverAuth = tls.RequireAnyClientCert
} }
certPool := x509.NewCertPool()
if params.Conf.CaCertificatePath == "" {
return nil, errors.New("CA Cert is not specified")
}
caCert, err := os.ReadFile(params.Conf.CaCertificatePath)
if err != nil {
return nil, err
}
if ok := certPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("could not parse PEM")
}
serverConfig := &tls.Config{ serverConfig := &tls.Config{
ClientAuth: serverAuth, ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
ClientCAs: certPool,
} }
server := grpc.NewServer( server := grpc.NewServer(
@ -61,7 +79,6 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
syncProvider := params.SyncProvider syncProvider := params.SyncProvider
connServer := ConnectionServer{ connServer := ConnectionServer{
serverConfig: serverConfig,
server: server, server: server,
ctrlProvider: ctrlProvider, ctrlProvider: ctrlProvider,
syncProvider: syncProvider, syncProvider: syncProvider,
@ -74,7 +91,6 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
// Listen for incoming requests. Returns an error if something went wrong. // Listen for incoming requests. Returns an error if something went wrong.
func (s *ConnectionServer) Listen() error { func (s *ConnectionServer) Listen() error {
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider) rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
rpc.RegisterSyncServiceServer(s.server, s.syncProvider) rpc.RegisterSyncServiceServer(s.server, s.syncProvider)
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Conf.GrpcPort)) lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Conf.GrpcPort))

View File

@ -16,6 +16,11 @@ func (s *ConnectionManagerStub) AddConnection(endPoint string) (PeerConnection,
return mock, nil return mock, nil
} }
func (s *ConnectionManagerStub) RemoveConnection(endPoint string) error {
delete(s.Endpoints, endPoint)
return nil
}
func (s *ConnectionManagerStub) GetConnection(endPoint string) (PeerConnection, error) { func (s *ConnectionManagerStub) GetConnection(endPoint string) (PeerConnection, error) {
endpoint, ok := s.Endpoints[endPoint] endpoint, ok := s.Endpoints[endPoint]

View File

@ -1,84 +0,0 @@
package conn
import (
"errors"
"slices"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// ConnectionWindow maintains a sliding window of connections between users
type ConnectionWindow interface {
// GetWindow is a list of connections to choose from
GetWindow() []string
// SlideConnection removes a node from the window and adds a random node
// not already in the window. connList represents the list of possible
// connections to choose from
SlideConnection(connList []string) error
// PushConneciton is used when connection list less than window size.
PutConnection(conn []string) error
// IsFull returns true if the window is full. In which case we must slide the window
IsFull() bool
}
type ConnectionWindowImpl struct {
window []string
windowSize int
}
// GetWindow gets the current list of active connections in
// the window
func (c *ConnectionWindowImpl) GetWindow() []string {
return c.window
}
// SlideConnection slides the connection window by one shuffling items
// in the windows
func (c *ConnectionWindowImpl) SlideConnection(connList []string) error {
// If the number of peer connections is less than the length of the window
// then exit early. Can't slide the window it should contain all nodes!
if len(c.window) < c.windowSize {
return nil
}
filter := func(node string) bool {
return !slices.Contains(c.window, node)
}
pool := lib.Filter(connList, filter)
newNode := lib.RandomSubsetOfLength(pool, 1)
if len(newNode) == 0 {
return errors.New("could not slide window")
}
for i := len(c.window) - 1; i >= 1; i-- {
c.window[i] = c.window[i-1]
}
c.window[0] = newNode[0]
return nil
}
// PutConnection put random connections in the connection
func (c *ConnectionWindowImpl) PutConnection(connList []string) error {
if len(c.window) >= c.windowSize {
return errors.New("cannot place connection. Window full need to slide")
}
c.window = lib.RandomSubsetOfLength(connList, c.windowSize)
return nil
}
func (c *ConnectionWindowImpl) IsFull() bool {
return len(c.window) >= c.windowSize
}
func NewConnectionWindow(windowLength int) ConnectionWindow {
window := &ConnectionWindowImpl{
window: make([]string, 0),
windowSize: windowLength,
}
return window
}

264
pkg/cplane/requester.go Normal file
View File

@ -0,0 +1,264 @@
package robin
import (
"context"
"errors"
"fmt"
"slices"
"time"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/rpc"
)
// IpcHandler: represents a handler for ipc calls
type IpcHandler struct {
Server ctrlserver.CtrlServer
}
// getOverrideConfiguration: override any specific WireGuard configuration
func getOverrideConfiguration(args *ipc.WireGuardArgs) conf.WgConfiguration {
overrideConf := conf.WgConfiguration{}
if args.Role != "" {
role := conf.NodeType(args.Role)
overrideConf.Role = &role
}
if args.Endpoint != "" {
overrideConf.Endpoint = &args.Endpoint
}
if args.KeepAliveWg != 0 {
keepAliveWg := args.KeepAliveWg
overrideConf.KeepAliveWg = &keepAliveWg
}
overrideConf.AdvertiseRoutes = &args.AdvertiseRoutes
overrideConf.AdvertiseDefaultRoute = &args.AdvertiseDefaultRoute
return overrideConf
}
// CreateMesh: create a new mesh network
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
overrideConf := getOverrideConfiguration(&args.WgArgs)
meshId, err := n.Server.GetMeshManager().CreateMesh(&mesh.CreateMeshParams{
Port: args.WgArgs.WgPort,
Conf: &overrideConf,
})
if err != nil {
return errors.New("could not create mesh")
}
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
MeshId: meshId,
WgPort: args.WgArgs.WgPort,
Endpoint: args.WgArgs.Endpoint,
})
if err != nil {
return errors.New("could not create mesh: " + err.Error())
}
*reply = meshId
return err
}
// ListMeshes: list mesh networks
func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes()))
i := 0
for meshId := range n.Server.GetMeshManager().GetMeshes() {
meshNames[i] = meshId
i++
}
slices.Sort(meshNames)
*reply = ipc.ListMeshReply{Meshes: meshNames}
return nil
}
// JoinMesh: join a mesh network
func (n *IpcHandler) JoinMesh(args *ipc.JoinMeshArgs, reply *string) error {
overrideConf := getOverrideConfiguration(&args.WgArgs)
if n.Server.GetMeshManager().GetMesh(args.MeshId) != nil {
return fmt.Errorf("user is already a part of the mesh")
}
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAddress)
if err != nil {
return fmt.Errorf("could not join mesh %s", args.MeshId)
}
client, err := peerConnection.GetClient()
if err != nil {
return fmt.Errorf("could not join mesh %s", args.MeshId)
}
c := rpc.NewMeshCtrlServerClient(client)
if err != nil {
return fmt.Errorf("could not join mesh %s", args.MeshId)
}
configuration := n.Server.GetConfiguration()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(configuration.Timeout))
defer cancel()
meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId})
if err != nil {
return fmt.Errorf("could not join mesh %s", args.MeshId)
}
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: args.MeshId,
WgPort: args.WgArgs.WgPort,
MeshBytes: meshReply.Mesh,
Conf: &overrideConf,
})
if err != nil {
return fmt.Errorf("could not join mesh %s", args.MeshId)
}
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
MeshId: args.MeshId,
WgPort: args.WgArgs.WgPort,
Endpoint: args.WgArgs.Endpoint,
})
if err != nil {
return fmt.Errorf("could not join mesh %s", args.MeshId)
}
*reply = fmt.Sprintf("Successfully Joined: %s", args.MeshId)
return nil
}
// LeaveMesh: leaves a mesh network
func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
err := n.Server.GetMeshManager().LeaveMesh(meshId)
if err == nil {
*reply = fmt.Sprintf("Left Mesh %s", meshId)
}
return err
}
// GetMesh: get a mesh network at the given meshid
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
theMesh := n.Server.GetMeshManager().GetMesh(meshId)
if theMesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
meshSnapshot, err := theMesh.GetMesh()
if err != nil {
return err
}
if theMesh == nil {
return errors.New("mesh does not exist")
}
nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.GetNodes()))
i := 0
for _, node := range meshSnapshot.GetNodes() {
node := ctrlserver.NewCtrlNode(theMesh, node)
nodes[i] = *node
i += 1
}
*reply = ipc.GetMeshReply{Nodes: nodes}
return nil
}
// Query: perform a jmespath query
func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
queryResponse, err := n.Server.GetQuerier().Query(params.MeshId, params.Query)
if err != nil {
return err
}
*reply = string(queryResponse)
return nil
}
// PutDescription: change your description in the mesh
func (n *IpcHandler) PutDescription(args ipc.PutDescriptionArgs, reply *string) error {
err := n.Server.GetMeshManager().SetDescription(args.MeshId, args.Description)
if err != nil {
return err
}
*reply = fmt.Sprintf("set description to %s for %s", args.Description, args.MeshId)
return nil
}
// PutAlias: put your aliasin the mesh
func (n *IpcHandler) PutAlias(args ipc.PutAliasArgs, reply *string) error {
if args.Alias == "" {
return fmt.Errorf("alias not provided")
}
err := n.Server.GetMeshManager().SetAlias(args.MeshId, args.Alias)
if err != nil {
return fmt.Errorf("could not set alias: %s", args.Alias)
}
*reply = fmt.Sprintf("Set alias to %s", args.Alias)
return nil
}
// PutService: place a service in the mesh
func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().SetService(service.MeshId, service.Service, service.Value)
if err != nil {
return err
}
*reply = fmt.Sprintf("Set service %s in %s to %s", service.Service, service.MeshId, service.Value)
return nil
}
// DeleteService: withtract a service in the mesh
func (n *IpcHandler) DeleteService(service ipc.DeleteServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service.MeshId, service.Service)
if err != nil {
return err
}
*reply = fmt.Sprintf("Removed service %s from %s", service.Service, service.MeshId)
return nil
}
// RobinIpcParams: parameters required to construct a new mesh network
type RobinIpcParams struct {
CtrlServer ctrlserver.CtrlServer
}
func NewRobinIpc(ipcParams RobinIpcParams) IpcHandler {
return IpcHandler{
Server: ipcParams.CtrlServer,
}
}

View File

@ -3,9 +3,10 @@ package robin
import ( import (
"testing" "testing"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/smegmesh/pkg/mesh"
) )
func getRequester() *IpcHandler { func getRequester() *IpcHandler {
@ -17,9 +18,11 @@ func TestCreateMeshRepliesMeshId(t *testing.T) {
requester := getRequester() requester := getRequester()
err := requester.CreateMesh(&ipc.NewMeshArgs{ err := requester.CreateMesh(&ipc.NewMeshArgs{
IfName: "wg0", WgArgs: ipc.WireGuardArgs{
WgPort: 5000, WgPort: 500,
Endpoint: "abc.com", Endpoint: "abc.com:1234",
Role: "peer",
},
}, &reply) }, &reply)
if err != nil { if err != nil {
@ -52,9 +55,8 @@ func TestListMeshesMeshesNotEmpty(t *testing.T) {
requester.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{ requester.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: "tim123", MeshId: "tim123",
DevName: "wg0",
WgPort: 5000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
Conf: &conf.WgConfiguration{},
}) })
err := requester.ListMeshes("", &reply) err := requester.ListMeshes("", &reply)

View File

@ -4,15 +4,17 @@ import (
"context" "context"
"errors" "errors"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
) )
// WgRpc: represents a WireGuard rpc call
type WgRpc struct { type WgRpc struct {
rpc.UnimplementedMeshCtrlServerServer rpc.UnimplementedMeshCtrlServerServer
Server *ctrlserver.MeshCtrlServer Server *ctrlserver.MeshCtrlServer
} }
// GetMesh: serialise the mesh network into bytes
func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) { func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) {
mesh := m.Server.MeshManager.GetMesh(request.MeshId) mesh := m.Server.MeshManager.GetMesh(request.MeshId)

View File

@ -9,16 +9,19 @@ import (
"strings" "strings"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// Route: represents a route within the data store
type Route struct { type Route struct {
// Destination the route is advertising
Destination string Destination string
// Path to the destination
Path []string Path []string
} }
@ -158,8 +161,8 @@ type TwoPhaseStoreMeshManager struct {
IfName string IfName string
Client *wgctrl.Client Client *wgctrl.Client
LastClock uint64 LastClock uint64
conf *conf.WgConfiguration Conf *conf.WgConfiguration
daemonConf *conf.DaemonConfiguration DaemonConf *conf.DaemonConfiguration
store *TwoPhaseMap[string, MeshNode] store *TwoPhaseMap[string, MeshNode]
} }
@ -204,7 +207,6 @@ func (m *TwoPhaseStoreMeshManager) Save() []byte {
var buf bytes.Buffer var buf bytes.Buffer
enc := gob.NewEncoder(&buf) enc := gob.NewEncoder(&buf)
err := enc.Encode(*snapshot) err := enc.Encode(*snapshot)
if err != nil { if err != nil {
@ -249,7 +251,8 @@ func (m *TwoPhaseStoreMeshManager) SaveChanges() {
m.LastClock = clockValue m.LastClock = clockValue
} }
// UpdateTimeStamp: update the timestamp of the given node // UpdateTimeStamp: update the timestamp of the given node, causes a configuration refresh if the node
// is the leader causing all nodes to update their vector clocks
func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error { func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
if !m.store.Contains(nodeId) { if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -265,7 +268,7 @@ func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
peerToUpdate := peers[0] peerToUpdate := peers[0]
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.daemonConf.KeepAliveTime) { if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.DaemonConf.Heartbeat) {
m.store.Mark(peerToUpdate) m.store.Mark(peerToUpdate)
if len(peers) < 2 { if len(peers) < 2 {
@ -313,6 +316,8 @@ func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route
} }
} }
// Only add nodes on changes. Otherwise the node will advertise new
// information whenever they get new routes
if changes { if changes {
m.store.Put(nodeId, node) m.store.Put(nodeId, node)
} }
@ -320,7 +325,7 @@ func (m *TwoPhaseStoreMeshManager) AddRoutes(nodeId string, routes ...mesh.Route
return nil return nil
} }
// DeleteRoutes: deletes the routes from the node // RemoveRoute: deletes the routes from the given node
func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error { func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Route) error {
if !m.store.Contains(nodeId) { if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -336,6 +341,7 @@ func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Ro
for _, route := range routes { for _, route := range routes {
changes = true changes = true
logging.Log.WriteInfof("deleting route: %s", route.GetDestination().String())
delete(node.Routes, route.GetDestination().String()) delete(node.Routes, route.GetDestination().String())
} }
@ -346,12 +352,12 @@ func (m *TwoPhaseStoreMeshManager) RemoveRoutes(nodeId string, routes ...mesh.Ro
return nil return nil
} }
// GetSyncer: returns the automerge syncer for sync // GetSyncer: returns the bi-directionally synchroniser to merge documents
func (m *TwoPhaseStoreMeshManager) GetSyncer() mesh.MeshSyncer { func (m *TwoPhaseStoreMeshManager) GetSyncer() mesh.MeshSyncer {
return NewTwoPhaseSyncer(m) return NewTwoPhaseSyncer(m)
} }
// GetNode get a particular not within the mesh // GetNode: get a particular not within the mesh network
func (m *TwoPhaseStoreMeshManager) GetNode(nodeId string) (mesh.MeshNode, error) { func (m *TwoPhaseStoreMeshManager) GetNode(nodeId string) (mesh.MeshNode, error) {
if !m.store.Contains(nodeId) { if !m.store.Contains(nodeId) {
return nil, fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) return nil, fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -379,20 +385,20 @@ func (m *TwoPhaseStoreMeshManager) SetDescription(nodeId string, description str
return nil return nil
} }
// SetAlias: set the alias of the nodeId // SetAlias: set the alias of the given node
func (m *TwoPhaseStoreMeshManager) SetAlias(nodeId string, alias string) error { func (m *TwoPhaseStoreMeshManager) SetAlias(nodeId string, alias string) error {
if !m.store.Contains(nodeId) { if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
} }
node := m.store.Get(nodeId) node := m.store.Get(nodeId)
node.Description = alias node.Alias = alias
m.store.Put(nodeId, node) m.store.Put(nodeId, node)
return nil return nil
} }
// AddService: adds the service to the given node // AddService: adds a service to the given node
func (m *TwoPhaseStoreMeshManager) AddService(nodeId string, key string, value string) error { func (m *TwoPhaseStoreMeshManager) AddService(nodeId string, key string, value string) error {
if !m.store.Contains(nodeId) { if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -404,19 +410,25 @@ func (m *TwoPhaseStoreMeshManager) AddService(nodeId string, key string, value s
return nil return nil
} }
// RemoveService: removes the service form the node. throws an error if the service does not exist // RemoveService: removes the service form a node, throws an error if the service does not exist
func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) error { func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) error {
if !m.store.Contains(nodeId) { if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
} }
node := m.store.Get(nodeId) node := m.store.Get(nodeId)
if _, ok := node.Services[key]; !ok {
return fmt.Errorf("datastore: node does not contain service %s", key)
}
delete(node.Services, key) delete(node.Services, key)
m.store.Put(nodeId, node) m.store.Put(nodeId, node)
return nil return nil
} }
// Prune: prunes all nodes that have not updated their timestamp in // Prune: prunes all nodes that have not updated their vector clock in a given amount
// of time
func (m *TwoPhaseStoreMeshManager) Prune() error { func (m *TwoPhaseStoreMeshManager) Prune() error {
m.store.Prune() m.store.Prune()
return nil return nil
@ -445,6 +457,7 @@ func (m *TwoPhaseStoreMeshManager) GetPeers() []string {
}) })
} }
// getRoutes: get all routes the target node is advertising
func (m *TwoPhaseStoreMeshManager) getRoutes(targetNode string) (map[string]Route, error) { func (m *TwoPhaseStoreMeshManager) getRoutes(targetNode string) (map[string]Route, error) {
if !m.store.Contains(targetNode) { if !m.store.Contains(targetNode) {
return nil, fmt.Errorf("getRoute: cannot get route %s does not exist", targetNode) return nil, fmt.Errorf("getRoute: cannot get route %s does not exist", targetNode)
@ -454,7 +467,8 @@ func (m *TwoPhaseStoreMeshManager) getRoutes(targetNode string) (map[string]Rout
return node.Routes, nil return node.Routes, nil
} }
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen // GetRoutes: Get all unique routes the target node is advertising.
// on conflicts the route with the least hop count is chosen
func (m *TwoPhaseStoreMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) { func (m *TwoPhaseStoreMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode) node, err := m.GetNode(targetNode)
@ -498,7 +512,7 @@ func (m *TwoPhaseStoreMeshManager) GetRoutes(targetNode string) (map[string]mesh
return routes, nil return routes, nil
} }
// RemoveNode(): remove the node from the mesh // RemoveNode: remove the node from the mesh
func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error { func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
if !m.store.Contains(nodeId) { if !m.store.Contains(nodeId) {
return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId) return fmt.Errorf("datastore: %s does not exist in the mesh", nodeId)
@ -508,7 +522,8 @@ func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
return nil return nil
} }
// GetConfiguration implements mesh.MeshProvider. // GetConfiguration gets the WireGuard configuration to use for this
// network
func (m *TwoPhaseStoreMeshManager) GetConfiguration() *conf.WgConfiguration { func (m *TwoPhaseStoreMeshManager) GetConfiguration() *conf.WgConfiguration {
return m.conf return m.Conf
} }

439
pkg/crdt/datastore_test.go Normal file
View File

@ -0,0 +1,439 @@
package crdt
import (
"net"
"slices"
"testing"
"time"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type TestParams struct {
manager mesh.MeshProvider
publicKey *wgtypes.Key
}
func setUpTests() *TestParams {
advertiseRoutes := false
advertiseDefaultRoute := false
role := conf.PEER_ROLE
discovery := conf.OUTGOING_IP_DISCOVERY
factory := &TwoPhaseMapFactory{
Config: &conf.DaemonConfiguration{
CertificatePath: "/somecertificatepath",
PrivateKeyPath: "/someprivatekeypath",
CaCertificatePath: "/somecacertificatepath",
SkipCertVerification: true,
GrpcPort: 0,
Timeout: 20,
SyncInterval: 2,
Heartbeat: 10,
ClusterSize: 32,
InterClusterChance: 0.15,
Branch: 3,
InfectionCount: 3,
BaseConfiguration: conf.WgConfiguration{
IPDiscovery: &discovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Role: &role,
},
},
}
key, _ := wgtypes.GeneratePrivateKey()
mesh, _ := factory.CreateMesh(&mesh.MeshProviderFactoryParams{
DevName: "bob",
MeshId: "meshid123",
Client: nil,
Conf: &factory.Config.BaseConfiguration,
DaemonConf: factory.Config,
NodeID: "bob",
})
publicKey := key.PublicKey()
return &TestParams{
manager: mesh,
publicKey: &publicKey,
}
}
func getOurNode(testParams *TestParams) *MeshNode {
return &MeshNode{
HostEndpoint: "public-endpoint:8080",
WgEndpoint: "public-endpoint:21906",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128",
PublicKey: testParams.publicKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
Type: "peer",
}
}
func getRandomNode() *MeshNode {
key, _ := wgtypes.GeneratePrivateKey()
publicKey := key.PublicKey()
return &MeshNode{
HostEndpoint: "public-endpoint:8081",
WgEndpoint: "public-endpoint:21907",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d234/128",
PublicKey: publicKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
Type: "peer",
}
}
func TestAddNodeAddsTheNodesToTheStore(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
if !testParams.manager.NodeExists(testParams.publicKey.String()) {
t.Fatalf(`node %s should have been added to the mesh network`, testParams.publicKey.String())
}
}
func TestAddNodeNodeAlreadyExistsReplacesTheNode(t *testing.T) {
TestAddNodeAddsTheNodesToTheStore(t)
TestAddNodeAddsTheNodesToTheStore(t)
}
func TestSaveThenLoad(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
bytes := testParams.manager.Save()
if err := testParams.manager.Load(bytes); err != nil {
t.Fatalf(`error caused by loading datastore: %s`, err.Error())
}
}
func TestHasChangesReturnsTrueWhenThereAreChangesInTheMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
if !testParams.manager.HasChanges() {
t.Fatalf(`mesh has change but HasChanges returned false`)
}
testParams.manager.SetDescription(testParams.publicKey.String(), "Bob marley")
if !testParams.manager.HasChanges() {
t.Fatalf(`mesh has change but HasChanges returned false`)
}
testParams.manager.SaveChanges()
}
func TestHasChangesWhenThereAreNoChangesInTheMeshReturnsFalse(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`mesh has no changes but HasChanges was true`)
}
testParams.manager.SetDescription(testParams.publicKey.String(), "Bob marley")
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`mesh has no changes but HasChanges was true`)
}
}
func TestUpdateTimeStampUpdatesTheTimeStampOfTheGivenNodeIfItIsTheLeader(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
before, _ := testParams.manager.GetNode(testParams.publicKey.String())
time.Sleep(1 * time.Second)
testParams.manager.UpdateTimeStamp(testParams.publicKey.String())
after, _ := testParams.manager.GetNode(testParams.publicKey.String())
if before.GetTimeStamp() >= after.GetTimeStamp() {
t.Fatalf(`before should not be after after`)
}
}
func TestUpdateTimeStampUpdatesTheTimeStampOfTheGivenNodeIfItIsNotLeader(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
newNode := getRandomNode()
newNode.PublicKey = "aaaaaaaaaa"
testParams.manager.AddNode(newNode)
before, _ := testParams.manager.GetNode(testParams.publicKey.String())
time.Sleep(1 * time.Second)
after, _ := testParams.manager.GetNode(testParams.publicKey.String())
if before.GetTimeStamp() != after.GetTimeStamp() {
t.Fatalf(`before and after should be the same`)
}
}
func TestAddRoutesAddsARouteToTheGivenMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
_, destination, _ := net.ParseCIDR("0353:1da7:7f33:acc0:7a3f:6e55:912b:bc1f/64")
testParams.manager.AddRoutes(testParams.publicKey.String(), &mesh.RouteStub{
Destination: destination,
Path: make([]string, 0),
})
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
containsDestination := lib.Contains(node.GetRoutes(), func(r mesh.Route) bool {
return r.GetDestination().Contains(destination.IP)
})
if !containsDestination {
t.Fatalf(`route has not been added to the node`)
}
}
func TestRemoveRoutesWithdrawsRoutesFromTheMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
_, destination, _ := net.ParseCIDR("0353:1da7:7f33:acc0:7a3f:6e55:912b:bc1f/64")
route := &mesh.RouteStub{
Destination: destination,
Path: make([]string, 0),
}
testParams.manager.AddRoutes(testParams.publicKey.String(), route)
testParams.manager.RemoveRoutes(testParams.publicKey.String(), route)
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
containsDestination := lib.Contains(node.GetRoutes(), func(r mesh.Route) bool {
return r.GetDestination().Contains(destination.IP)
})
if containsDestination {
t.Fatalf(`route has not been removed from the node`)
}
}
func TestGetNodeGetsTheNodeWhenItExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
if node == nil {
t.Fatalf(`node not found returned nil`)
}
}
func TestGetNodeReturnsNilWhenItDoesNotExist(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.RemoveNode(testParams.publicKey.String())
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
if node != nil {
t.Fatalf(`node found but should be nil`)
}
}
func TestNodeExistsReturnsFalseWhenNotExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.RemoveNode(testParams.publicKey.String())
if testParams.manager.NodeExists(testParams.publicKey.String()) {
t.Fatalf(`nodeexists should be false`)
}
}
func TestSetDescriptionReturnsErrorWhenNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetDescription("djdjdj", "djdsjkd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestSetDescriptionSetsTheDescription(t *testing.T) {
testParams := setUpTests()
descriptionToSet := "djdsjkd"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.SetDescription(testParams.publicKey.String(), descriptionToSet)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
description := node.GetDescription()
if description != descriptionToSet {
t.Fatalf(`description was %s should be %s`, description, descriptionToSet)
}
}
func TestAliasNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetAlias("djdjdj", "djdsjkd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestSetAliasSetsAlias(t *testing.T) {
testParams := setUpTests()
aliasToSet := "djdsjkd"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.SetAlias(testParams.publicKey.String(), aliasToSet)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
alias := node.GetAlias()
if alias != aliasToSet {
t.Fatalf(`description was %s should be %s`, alias, aliasToSet)
}
}
func TestAddServiceNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.AddService("djdjdj", "djdsjkd", "sddsds")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestAddServiceNodeExists(t *testing.T) {
testParams := setUpTests()
service := "djdsjkd"
serviceValue := "dsdsds"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.AddService(testParams.publicKey.String(), service, serviceValue)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
services := node.GetServices()
if value, ok := services[service]; !ok || value != serviceValue {
t.Fatalf(`service not added to the data store`)
}
}
func TestRemoveServiceDoesNotExists(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.RemoveService("djdjdj", "dsdssd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestRemoveServiceServiceDoesNotExist(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
if err := testParams.manager.RemoveService(testParams.publicKey.String(), "dhsdh"); err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestGetPeersReturnsAllPeersInTheMesh(t *testing.T) {
testParams := setUpTests()
peer1 := getRandomNode()
peer2 := getRandomNode()
client := getRandomNode()
client.Type = "client"
testParams.manager.AddNode(peer1)
testParams.manager.AddNode(peer2)
testParams.manager.AddNode(client)
peers := testParams.manager.GetPeers()
slices.Sort(peers)
if len(peers) != 2 {
t.Fatalf(`there should be two peers in the mesh`)
}
peer1Pub, _ := peer1.GetPublicKey()
if !slices.Contains(peers, peer1Pub.String()) {
t.Fatalf(`peer1 not in the list`)
}
peer2Pub, _ := peer2.GetPublicKey()
if !slices.Contains(peers, peer2Pub.String()) {
t.Fatalf(`peer2 not in the list`)
}
}
func TestRemoveNodeReturnsErrorIfNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.RemoveNode("dsjdssjk")
if err == nil {
t.Fatalf(`error should have returned`)
}
}

View File

@ -4,34 +4,39 @@ import (
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
// TwoPhaseMapFactory: instantiate a new twophasemap
// datastore
type TwoPhaseMapFactory struct { type TwoPhaseMapFactory struct {
Config *conf.DaemonConfiguration Config *conf.DaemonConfiguration
} }
// CreateMesh: create a new mesh network
func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) { func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
return &TwoPhaseStoreMeshManager{ return &TwoPhaseStoreMeshManager{
MeshId: params.MeshId, MeshId: params.MeshId,
IfName: params.DevName, IfName: params.DevName,
Client: params.Client, Client: params.Client,
conf: params.Conf, Conf: params.Conf,
daemonConf: params.DaemonConf, DaemonConf: params.DaemonConf,
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 { store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
h := fnv.New64a() h := fnv.New64a()
h.Write([]byte(s)) h.Write([]byte(s))
return h.Sum64() return h.Sum64()
}, uint64(3*f.Config.KeepAliveTime)), }, uint64(3*f.Config.Heartbeat)),
}, nil }, nil
} }
// MeshNodeFactory: create a new node in the mesh network
type MeshNodeFactory struct { type MeshNodeFactory struct {
Config conf.DaemonConfiguration Config conf.DaemonConfiguration
} }
// Build: build a new mesh network
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode { func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params) hostName := f.getAddress(params)
@ -66,7 +71,7 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
} else { } else {
ipFunc := lib.GetPublicIP ipFunc := lib.GetPublicIP
if *params.MeshConfig.IPDiscovery == conf.DNS_IP_DISCOVERY { if *params.MeshConfig.IPDiscovery == conf.OUTGOING_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP ipFunc = lib.GetOutboundIP
} }

View File

@ -1,4 +1,4 @@
// crdt is a golang implementation of a crdt // crdt provides go implementations for crdts
package crdt package crdt
import ( import (
@ -6,6 +6,7 @@ import (
"sync" "sync"
) )
// Bucket: bucket represents a value in the grow only map
type Bucket[D any] struct { type Bucket[D any] struct {
Vector uint64 Vector uint64
Contents D Contents D
@ -19,6 +20,7 @@ type GMap[K cmp.Ordered, D any] struct {
clock *VectorClock[K] clock *VectorClock[K]
} }
// Put: put a new entry in the grow-only-map
func (g *GMap[K, D]) Put(key K, value D) { func (g *GMap[K, D]) Put(key K, value D) {
g.lock.Lock() g.lock.Lock()
@ -32,6 +34,8 @@ func (g *GMap[K, D]) Put(key K, value D) {
g.lock.Unlock() g.lock.Unlock()
} }
// Contains: returns whether or not the key is contained
// in the g-map
func (g *GMap[K, D]) Contains(key K) bool { func (g *GMap[K, D]) Contains(key K) bool {
return g.contains(g.clock.hashFunc(key)) return g.contains(g.clock.hashFunc(key))
} }
@ -64,11 +68,23 @@ func (g *GMap[K, D]) get(key uint64) Bucket[D] {
return bucket return bucket
} }
// Get: get the value associated with the given key
func (g *GMap[K, D]) Get(key K) D { func (g *GMap[K, D]) Get(key K) D {
if !g.Contains(key) {
var def D
return def
}
return g.get(g.clock.hashFunc(key)).Contents return g.get(g.clock.hashFunc(key)).Contents
} }
// Mark: marks the node, this means the status of the node
// is an undefined state
func (g *GMap[K, D]) Mark(key K) { func (g *GMap[K, D]) Mark(key K) {
if !g.Contains(key) {
return
}
g.lock.Lock() g.lock.Lock()
bucket := g.contents[g.clock.hashFunc(key)] bucket := g.contents[g.clock.hashFunc(key)]
bucket.Gravestone = true bucket.Gravestone = true
@ -76,7 +92,7 @@ func (g *GMap[K, D]) Mark(key K) {
g.lock.Unlock() g.lock.Unlock()
} }
// IsMarked: returns true if the node is marked // IsMarked: returns true if the node is marked (in an undefined state)
func (g *GMap[K, D]) IsMarked(key K) bool { func (g *GMap[K, D]) IsMarked(key K) bool {
marked := false marked := false
@ -89,10 +105,10 @@ func (g *GMap[K, D]) IsMarked(key K) bool {
} }
g.lock.RUnlock() g.lock.RUnlock()
return marked return marked
} }
// Keys: return all the keys in the grow-only map
func (g *GMap[K, D]) Keys() []uint64 { func (g *GMap[K, D]) Keys() []uint64 {
g.lock.RLock() g.lock.RLock()
@ -108,6 +124,7 @@ func (g *GMap[K, D]) Keys() []uint64 {
return contents return contents
} }
// Save: saves the grow only map
func (g *GMap[K, D]) Save() map[uint64]Bucket[D] { func (g *GMap[K, D]) Save() map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D]) buckets := make(map[uint64]Bucket[D])
g.lock.RLock() g.lock.RLock()
@ -120,6 +137,7 @@ func (g *GMap[K, D]) Save() map[uint64]Bucket[D] {
return buckets return buckets
} }
// SaveWithKeys: get all the values corresponding with the provided keys
func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] { func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] {
buckets := make(map[uint64]Bucket[D]) buckets := make(map[uint64]Bucket[D])
g.lock.RLock() g.lock.RLock()
@ -132,6 +150,7 @@ func (g *GMap[K, D]) SaveWithKeys(keys []uint64) map[uint64]Bucket[D] {
return buckets return buckets
} }
// GetClock: get all the vector clocks in the g_map
func (g *GMap[K, D]) GetClock() map[uint64]uint64 { func (g *GMap[K, D]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64) clock := make(map[uint64]uint64)
g.lock.RLock() g.lock.RLock()
@ -144,6 +163,7 @@ func (g *GMap[K, D]) GetClock() map[uint64]uint64 {
return clock return clock
} }
// GetHash: get the hash of the g_map representing its state
func (g *GMap[K, D]) GetHash() uint64 { func (g *GMap[K, D]) GetHash() uint64 {
hash := uint64(0) hash := uint64(0)
@ -157,6 +177,7 @@ func (g *GMap[K, D]) GetHash() uint64 {
return hash return hash
} }
// Prune: prune all stale entries
func (g *GMap[K, D]) Prune() { func (g *GMap[K, D]) Prune() {
stale := g.clock.getStale() stale := g.clock.getStale()
g.lock.Lock() g.lock.Lock()

224
pkg/crdt/g_map_test.go Normal file
View File

@ -0,0 +1,224 @@
// crdt_test unit tests the crdt implementations
package crdt
import (
"hash/fnv"
"slices"
"testing"
"time"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
func NewGmap() *GMap[string, bool] {
vectorClock := NewVectorClock("a", func(key string) uint64 {
hash := fnv.New64a()
hash.Write([]byte(key))
return hash.Sum64()
}, 1) // 1 second stale time
gMap := NewGMap[string, bool](vectorClock)
return gMap
}
func TestGMapPutInsertsItems(t *testing.T) {
gMap := NewGmap()
gMap.Put("bruh1234", true)
if !gMap.Contains("bruh1234") {
t.Fatalf(`value not added to map`)
}
}
func TestGMapPutReplacesItems(t *testing.T) {
gMap := NewGmap()
gMap.Put("bruh1234", true)
gMap.Put("bruh1234", false)
value := gMap.Get("bruh1234")
if value {
t.Fatalf(`value should ahve been replaced to false`)
}
}
func TestContainsValueNotPresent(t *testing.T) {
gMap := NewGmap()
if gMap.Contains("sdhjsdhsdj") {
t.Fatalf(`value should not be present in the map`)
}
}
func TestContainsValuePresent(t *testing.T) {
gMap := NewGmap()
key := "hehehehe"
gMap.Put(key, false)
if !gMap.Contains(key) {
t.Fatalf(`%s should not be present in the map`, key)
}
}
func TestGMapGetNotPresentReturnsError(t *testing.T) {
gMap := NewGmap()
value := gMap.Get("bruh123")
if value != false {
t.Fatalf(`value should be default type false`)
}
}
func TestGMapGetReturnsValue(t *testing.T) {
gMap := NewGmap()
gMap.Put("bobdylan", true)
value := gMap.Get("bobdylan")
if !value {
t.Fatalf("value should be true but was false")
}
}
func TestMarkMarksTheValue(t *testing.T) {
gMap := NewGmap()
gMap.Put("hello123", true)
gMap.Mark("hello123")
if !gMap.IsMarked("hello123") {
t.Fatal(`hello123 should be marked`)
}
}
func TestMarkValueNotPresent(t *testing.T) {
gMap := NewGmap()
gMap.Mark("ok123456")
}
func TestKeysMapEmpty(t *testing.T) {
gMap := NewGmap()
keys := gMap.Keys()
if len(keys) != 0 {
t.Fatal(`list of keys was not empty but should be empty`)
}
}
func TestKeysMapReturnsKeysInMap(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
keys := gMap.Keys()
if len(keys) != 3 {
t.Fatal(`key length should be 3`)
}
}
func TestSaveMapEmptyReturnsEmptyMap(t *testing.T) {
gMap := NewGmap()
saveMap := gMap.Save()
if len(saveMap) != 0 {
t.Fatal(`saves should be empty`)
}
}
func TestSaveMapReturnsMapOfBuckets(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
saveMap := gMap.Save()
if len(saveMap) != 3 {
t.Fatalf(`save length should be 3`)
}
}
func TestSaveWithKeysNoKeysReturnsEmptyBucket(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
saveMap := gMap.SaveWithKeys([]uint64{})
if len(saveMap) != 0 {
t.Fatalf(`save map should be empty`)
}
}
func TestSaveWithKeysReturnsIntersection(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
clock := lib.MapKeys(gMap.GetClock())
clock = clock[:len(clock)-1]
values := gMap.SaveWithKeys(clock)
if len(values) != len(clock) {
t.Fatalf(`intersection not returned`)
}
}
func TestGetClockMapEmptyReturnsEmptyClock(t *testing.T) {
gMap := NewGmap()
clocks := gMap.GetClock()
if len(clocks) != 0 {
t.Fatalf(`vector clock is not empty`)
}
}
func TestGetClockReturnsAllCLocks(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
clocks := lib.MapValues(gMap.GetClock())
slices.Sort(clocks)
if !slices.Equal([]uint64{0, 1, 2}, clocks) {
t.Fatalf(`clocks are invalid`)
}
}
func TestGetHashChangesHashOnValueAdded(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
prevHash := gMap.GetHash()
gMap.Put("b", true)
if prevHash == gMap.GetHash() {
t.Fatalf(`hash should be different`)
}
}
func TestPruneGarbageCollectsValuesThatHaveNotBeenUpdated(t *testing.T) {
gMap := NewGmap()
gMap.clock.Put("c", 12)
gMap.Put("c", false)
gMap.Put("a", false)
time.Sleep(4 * time.Second)
gMap.Put("a", true)
gMap.Prune()
if gMap.Contains("c") {
t.Fatalf(`a should have been pruned`)
}
}

View File

@ -3,9 +3,10 @@ package crdt
import ( import (
"cmp" "cmp"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
) )
// TwoPhaseMap: comprises of two grow-only maps
type TwoPhaseMap[K cmp.Ordered, D any] struct { type TwoPhaseMap[K cmp.Ordered, D any] struct {
addMap *GMap[K, D] addMap *GMap[K, D]
removeMap *GMap[K, bool] removeMap *GMap[K, bool]
@ -23,7 +24,7 @@ func (m *TwoPhaseMap[K, D]) Contains(key K) bool {
return m.contains(m.Clock.hashFunc(key)) return m.contains(m.Clock.hashFunc(key))
} }
// Contains checks whether the value exists in the map // contains: checks whether the key exists in the map
func (m *TwoPhaseMap[K, D]) contains(key uint64) bool { func (m *TwoPhaseMap[K, D]) contains(key uint64) bool {
if !m.addMap.contains(key) { if !m.addMap.contains(key) {
return false return false
@ -40,6 +41,7 @@ func (m *TwoPhaseMap[K, D]) contains(key uint64) bool {
return addValue.Vector >= removeValue.Vector return addValue.Vector >= removeValue.Vector
} }
// Get: get the value corresponding with the given key
func (m *TwoPhaseMap[K, D]) Get(key K) D { func (m *TwoPhaseMap[K, D]) Get(key K) D {
var result D var result D
@ -60,18 +62,19 @@ func (m *TwoPhaseMap[K, D]) get(key uint64) D {
return m.addMap.get(key).Contents return m.addMap.get(key).Contents
} }
// Put places the key K in the map // Put: places the key K in the map with the associated data D
func (m *TwoPhaseMap[K, D]) Put(key K, data D) { func (m *TwoPhaseMap[K, D]) Put(key K, data D) {
msgSequence := m.Clock.IncrementClock() msgSequence := m.Clock.IncrementClock()
m.Clock.Put(key, msgSequence) m.Clock.Put(key, msgSequence)
m.addMap.Put(key, data) m.addMap.Put(key, data)
} }
// Mark: marks the status of the node as undetermiend
func (m *TwoPhaseMap[K, D]) Mark(key K) { func (m *TwoPhaseMap[K, D]) Mark(key K) {
m.addMap.Mark(key) m.addMap.Mark(key)
} }
// Remove removes the value from the map // Remove: removes the value from the map
func (m *TwoPhaseMap[K, D]) Remove(key K) { func (m *TwoPhaseMap[K, D]) Remove(key K) {
m.removeMap.Put(key, true) m.removeMap.Put(key, true)
} }
@ -92,6 +95,7 @@ func (m *TwoPhaseMap[K, D]) keys() []uint64 {
return keys return keys
} }
// AsList: convert the map to a list
func (m *TwoPhaseMap[K, D]) AsList() []D { func (m *TwoPhaseMap[K, D]) AsList() []D {
theList := make([]D, 0) theList := make([]D, 0)
@ -104,6 +108,8 @@ func (m *TwoPhaseMap[K, D]) AsList() []D {
return theList return theList
} }
// Snapshot: convert the map into an immutable snapshot.
// contains the contents of the add and remove map
func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] { func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
return &TwoPhaseMapSnapshot[K, D]{ return &TwoPhaseMapSnapshot[K, D]{
Add: m.addMap.Save(), Add: m.addMap.Save(),
@ -111,6 +117,8 @@ func (m *TwoPhaseMap[K, D]) Snapshot() *TwoPhaseMapSnapshot[K, D] {
} }
} }
// SnapshotFromState: create a snapshot of the intersection of values provided
// in the given state
func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPhaseMapSnapshot[K, D] { func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPhaseMapSnapshot[K, D] {
addKeys := lib.MapKeys(state.AddContents) addKeys := lib.MapKeys(state.AddContents)
removeKeys := lib.MapKeys(state.RemoveContents) removeKeys := lib.MapKeys(state.RemoveContents)
@ -121,12 +129,18 @@ func (m *TwoPhaseMap[K, D]) SnapShotFromState(state *TwoPhaseMapState[K]) *TwoPh
} }
} }
// TwoPhaseMapState: encapsulates the state of the map
// without specifying the data that is stored
type TwoPhaseMapState[K cmp.Ordered] struct { type TwoPhaseMapState[K cmp.Ordered] struct {
// Vectors: the vector ID of each process
Vectors map[uint64]uint64 Vectors map[uint64]uint64
// AddContents: the contents of the add map
AddContents map[uint64]uint64 AddContents map[uint64]uint64
// RemoveContents: the contents of the remove map
RemoveContents map[uint64]uint64 RemoveContents map[uint64]uint64
} }
// IsMarked: returns true if the given value is marked in an undetermined state
func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool { func (m *TwoPhaseMap[K, D]) IsMarked(key K) bool {
return m.addMap.IsMarked(key) return m.addMap.IsMarked(key)
} }
@ -151,7 +165,9 @@ func (m *TwoPhaseMap[K, D]) GenerateMessage() *TwoPhaseMapState[K] {
} }
} }
func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] { // Difference: compute the set difference between the two states.
// highestStale represents the highest vector clock that has been marked as stale
func (m *TwoPhaseMapState[K]) Difference(highestStale uint64, state *TwoPhaseMapState[K]) *TwoPhaseMapState[K] {
mapState := &TwoPhaseMapState[K]{ mapState := &TwoPhaseMapState[K]{
AddContents: make(map[uint64]uint64), AddContents: make(map[uint64]uint64),
RemoveContents: make(map[uint64]uint64), RemoveContents: make(map[uint64]uint64),
@ -160,7 +176,7 @@ func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMa
for key, value := range state.AddContents { for key, value := range state.AddContents {
otherValue, ok := m.AddContents[key] otherValue, ok := m.AddContents[key]
if !ok || otherValue < value { if value > highestStale && (!ok || otherValue < value) {
mapState.AddContents[key] = value mapState.AddContents[key] = value
} }
} }
@ -168,7 +184,7 @@ func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMa
for key, value := range state.RemoveContents { for key, value := range state.RemoveContents {
otherValue, ok := m.RemoveContents[key] otherValue, ok := m.RemoveContents[key]
if !ok || otherValue < value { if value > highestStale && (!ok || otherValue < value) {
mapState.RemoveContents[key] = value mapState.RemoveContents[key] = value
} }
} }
@ -176,6 +192,7 @@ func (m *TwoPhaseMapState[K]) Difference(state *TwoPhaseMapState[K]) *TwoPhaseMa
return mapState return mapState
} }
// Merge: merge a snapshot into the map
func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) { func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
for key, value := range snapshot.Add { for key, value := range snapshot.Add {
// Gravestone is local only to that node. // Gravestone is local only to that node.
@ -190,6 +207,7 @@ func (m *TwoPhaseMap[K, D]) Merge(snapshot TwoPhaseMapSnapshot[K, D]) {
} }
} }
// Prune: garbage collect all stale entries in the map
func (m *TwoPhaseMap[K, D]) Prune() { func (m *TwoPhaseMap[K, D]) Prune() {
m.addMap.Prune() m.addMap.Prune()
m.removeMap.Prune() m.removeMap.Prune()

View File

@ -4,7 +4,7 @@ import (
"bytes" "bytes"
"encoding/gob" "encoding/gob"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
) )
type SyncState int type SyncState int
@ -68,9 +68,16 @@ func prepare(syncer *TwoPhaseSyncer) ([]byte, bool) {
return nil, false return nil, false
} }
// Increment the clock here so the clock gets
// distributed to everyone else in the mesh
syncer.manager.store.Clock.IncrementClock()
var buffer bytes.Buffer var buffer bytes.Buffer
enc := gob.NewEncoder(&buffer) enc := gob.NewEncoder(&buffer)
mapState := syncer.manager.store.GenerateMessage()
syncer.mapState = mapState
err = enc.Encode(*syncer.mapState) err = enc.Encode(*syncer.mapState)
if err != nil { if err != nil {
@ -96,7 +103,7 @@ func present(syncer *TwoPhaseSyncer) ([]byte, bool) {
logging.Log.WriteErrorf(err.Error()) logging.Log.WriteErrorf(err.Error())
} }
difference := syncer.mapState.Difference(&mapState) difference := syncer.mapState.Difference(syncer.manager.store.Clock.GetStaleCount(), &mapState)
syncer.manager.store.Clock.Merge(mapState.Vectors) syncer.manager.store.Clock.Merge(mapState.Vectors)
var sendBuffer bytes.Buffer var sendBuffer bytes.Buffer
@ -164,9 +171,6 @@ func (t *TwoPhaseSyncer) RecvMessage(msg []byte) error {
func (t *TwoPhaseSyncer) Complete() { func (t *TwoPhaseSyncer) Complete() {
logging.Log.WriteInfof("SYNC COMPLETED") logging.Log.WriteInfof("SYNC COMPLETED")
if t.state >= MERGE {
t.manager.store.Clock.IncrementClock()
}
} }
func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer { func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
@ -181,7 +185,6 @@ func NewTwoPhaseSyncer(manager *TwoPhaseStoreMeshManager) *TwoPhaseSyncer {
return &TwoPhaseSyncer{ return &TwoPhaseSyncer{
manager: manager, manager: manager,
state: HASH, state: HASH,
mapState: manager.store.GenerateMessage(),
generateMessageFSM: generateMessageFsm, generateMessageFSM: generateMessageFsm,
} }
} }

View File

@ -0,0 +1,214 @@
package crdt
import (
"hash/fnv"
"slices"
"testing"
)
func NewMap(processId string) *TwoPhaseMap[string, string] {
theMap := NewTwoPhaseMap[string, string](processId, func(key string) uint64 {
hash := fnv.New64a()
hash.Write([]byte(key))
return hash.Sum64()
}, 1)
return theMap
}
func TestTwoPhaseMapEmpty(t *testing.T) {
theMap := NewMap("a")
if theMap.Contains("a") {
t.Fatalf(`a should not be present in the map`)
}
}
func TestTwoPhaseMapValuePresent(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
if !theMap.Contains("a") {
t.Fatalf(`should be present within the map`)
}
}
func TestTwoPhaseMapValueNotPresent(t *testing.T) {
theMap := NewMap("a")
theMap.Put("b", "")
if theMap.Contains("a") {
t.Fatalf(`a should not be present in the map`)
}
}
func TestTwoPhaseMapPutThenRemove(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Remove("a")
if theMap.Contains("a") {
t.Fatalf(`a should not be present within the map`)
}
}
func TestTwoPhaseMapPutThenRemoveThenPut(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Remove("a")
theMap.Put("a", "")
if !theMap.Contains("a") {
t.Fatalf(`a should be present within the map`)
}
}
func TestMarkMarksTheValueIn2PMap(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Mark("a")
if !theMap.IsMarked("a") {
t.Fatalf(`a should be marked`)
}
}
func TestAsListReturnsItemsInList(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "bob")
theMap.Put("b", "dylan")
keys := theMap.AsList()
slices.Sort(keys)
if !slices.Equal([]string{"bob", "dylan"}, keys) {
t.Fatalf(`values should be bob, dylan`)
}
}
func TestSnapShotRemoveMapEmpty(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "bob")
theMap.Put("b", "dylan")
snapshot := theMap.Snapshot()
if len(snapshot.Add) != 2 {
t.Fatalf(`add values length should be 2`)
}
if len(snapshot.Remove) != 0 {
t.Fatalf(`remove map length should be 0`)
}
}
func TestSnapshotMapEmpty(t *testing.T) {
theMap := NewMap("a")
snapshot := theMap.Snapshot()
if len(snapshot.Add) != 0 || len(snapshot.Remove) != 0 {
t.Fatalf(`snapshot length should be 0`)
}
}
func TestSnapShotFromStateReturnsIntersection(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "heyy")
map2 := NewMap("b")
map2.Put("b", "hmmm")
message := map2.GenerateMessage()
snapShot := map1.SnapShotFromState(message)
if len(snapShot.Add) != 1 {
t.Fatalf(`add length should be 1`)
}
if len(snapShot.Remove) != 0 {
t.Fatalf(`remove length should be 0`)
}
}
func TestGetHashDifferentOnChange(t *testing.T) {
theMap := NewMap("a")
prevHash := theMap.GetHash()
theMap.Put("b", "hmmhmhmh")
if prevHash == theMap.GetHash() {
t.Fatalf(`hashes should not be the same`)
}
}
func TestGenerateMessageReturnsClocks(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "hmm")
theMap.Put("b", "hmm")
theMap.Remove("a")
message := theMap.GenerateMessage()
if len(message.AddContents) != 2 {
t.Fatalf(`two items added add should be 2`)
}
if len(message.RemoveContents) != 1 {
t.Fatalf(`a was removed remove map should be length 1`)
}
}
func TestDifferenceReturnsDifferenceOfMaps(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "ssms")
map1.Put("b", "sdmdsmd")
map2 := NewMap("b")
map2.Put("d", "eek")
map2.Put("c", "meh")
message1 := map1.GenerateMessage()
message2 := map2.GenerateMessage()
difference := message1.Difference(0, message2)
if len(difference.AddContents) != 2 {
t.Fatalf(`d and c are not in map1 they should be in add contents`)
}
if len(difference.RemoveContents) != 0 {
t.Fatalf(`remove should be empty`)
}
}
func TestMergeMergesValuesThatAreGreaterThanCurrentClock(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "ssms")
map1.Put("b", "sdmdsmd")
map2 := NewMap("b")
map2.Put("d", "eek")
map2.Put("c", "meh")
message1 := map1.GenerateMessage()
message2 := map2.GenerateMessage()
difference := message1.Difference(0, message2)
state := map2.SnapShotFromState(difference)
map1.Merge(*state)
if !map1.Contains("d") {
t.Fatalf(`d should be in the map`)
}
if !map2.Contains("c") {
t.Fatalf(`c should be in the map`)
}
}

View File

@ -5,9 +5,12 @@ import (
"sync" "sync"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
) )
// VectorBucket: represents a vector clock in the bucket
// recording both the time changes were last seen
// and when the lastUpdate epoch was recorded
type VectorBucket struct { type VectorBucket struct {
// clock current value of the node's clock // clock current value of the node's clock
clock uint64 clock uint64
@ -15,14 +18,17 @@ type VectorBucket struct {
lastUpdate uint64 lastUpdate uint64
} }
// Vector clock defines an abstract data type // VectorClock: defines an abstract data type
// for a vector clock implementation // for a vector clock implementation. Including a mechanism to
// garbage collect stale entries
type VectorClock[K cmp.Ordered] struct { type VectorClock[K cmp.Ordered] struct {
vectors map[uint64]*VectorBucket vectors map[uint64]*VectorBucket
lock sync.RWMutex lock sync.RWMutex
processID K processID K
staleTime uint64 staleTime uint64
hashFunc func(K) uint64 hashFunc func(K) uint64
// highest update that's been garbage collected
highestStale uint64
} }
// IncrementClock: increments the node's value in the vector clock // IncrementClock: increments the node's value in the vector clock
@ -60,6 +66,7 @@ func (m *VectorClock[K]) GetHash() uint64 {
return hash return hash
} }
// Merge: merge two clocks together
func (m *VectorClock[K]) Merge(vectors map[uint64]uint64) { func (m *VectorClock[K]) Merge(vectors map[uint64]uint64) {
for key, value := range vectors { for key, value := range vectors {
m.put(key, value) m.put(key, value)
@ -78,6 +85,7 @@ func (m *VectorClock[K]) getStale() []uint64 {
for key, bucket := range m.vectors { for key, bucket := range m.vectors {
if maxTimeStamp-bucket.lastUpdate > m.staleTime { if maxTimeStamp-bucket.lastUpdate > m.staleTime {
toRemove = append(toRemove, key) toRemove = append(toRemove, key)
m.highestStale = max(bucket.clock, m.highestStale)
} }
} }
@ -85,6 +93,16 @@ func (m *VectorClock[K]) getStale() []uint64 {
return toRemove return toRemove
} }
// GetStaleCount: returns a vector clock which is considered to be stale.
// all updates must be greater than this
func (m *VectorClock[K]) GetStaleCount() uint64 {
m.lock.RLock()
staleCount := m.highestStale
m.lock.RUnlock()
return staleCount
}
// Prune: prunes all stale entries in the vector clock
func (m *VectorClock[K]) Prune() { func (m *VectorClock[K]) Prune() {
stale := m.getStale() stale := m.getStale()
@ -97,6 +115,8 @@ func (m *VectorClock[K]) Prune() {
m.lock.Unlock() m.lock.Unlock()
} }
// GetTimeStamp: get the last time the node was updated in UNIX
// epoch time
func (m *VectorClock[K]) GetTimestamp(processId K) uint64 { func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
m.lock.RLock() m.lock.RLock()
@ -106,6 +126,8 @@ func (m *VectorClock[K]) GetTimestamp(processId K) uint64 {
return lastUpdate return lastUpdate
} }
// Put: places the key with vector clock in the clock of the given
// process
func (m *VectorClock[K]) Put(key K, value uint64) { func (m *VectorClock[K]) Put(key K, value uint64) {
m.put(m.hashFunc(key), value) m.put(m.hashFunc(key), value)
} }
@ -120,7 +142,10 @@ func (m *VectorClock[K]) put(key uint64, value uint64) {
clockValue = bucket.clock clockValue = bucket.clock
} }
if value > clockValue { // Make sure that entries that were garbage collected don't get
// highestStale represents the highest vector clock that has been
// invalidated
if value > clockValue && value > m.highestStale {
newBucket := VectorBucket{ newBucket := VectorBucket{
clock: value, clock: value,
lastUpdate: uint64(time.Now().Unix()), lastUpdate: uint64(time.Now().Unix()),
@ -131,6 +156,7 @@ func (m *VectorClock[K]) put(key uint64, value uint64) {
m.lock.Unlock() m.lock.Unlock()
} }
// GetClock: serialize the vector clock into an immutable map
func (m *VectorClock[K]) GetClock() map[uint64]uint64 { func (m *VectorClock[K]) GetClock() map[uint64]uint64 {
clock := make(map[uint64]uint64) clock := make(map[uint64]uint64)

View File

@ -1,27 +1,27 @@
package ctrlserver package ctrlserver
import ( import (
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/crdt" "github.com/tim-beatham/smegmesh/pkg/crdt"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/smegmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
"github.com/tim-beatham/wgmesh/pkg/wg" "github.com/tim-beatham/smegmesh/pkg/sync"
"github.com/tim-beatham/smegmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )
// NewCtrlServerParams are the params requried to create a new ctrl server // NewCtrlServerParams are the params required to create a new ctrl server
type NewCtrlServerParams struct { type NewCtrlServerParams struct {
Conf *conf.DaemonConfiguration Conf *conf.DaemonConfiguration
Client *wgctrl.Client Client *wgctrl.Client
CtrlProvider rpc.MeshCtrlServerServer CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer SyncProvider rpc.SyncServiceServer
Querier query.Querier Querier query.Querier
OnDelete func(mesh.MeshProvider)
} }
// Create a new instance of the MeshCtrlServer or error if the // Create a new instance of the MeshCtrlServer or error if the
@ -34,11 +34,15 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
nodeFactory := &crdt.MeshNodeFactory{ nodeFactory := &crdt.MeshNodeFactory{
Config: *params.Conf, Config: *params.Conf,
} }
idGenerator := &lib.IDNameGenerator{} idGenerator := &lib.ShortIDGenerator{}
ipAllocator := &ip.ULABuilder{} ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client) interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
configApplyer := mesh.NewWgMeshConfigApplyer() ctrlServer.timers = make([]*lib.Timer, 0)
configApplier := mesh.NewWgMeshConfigApplier()
var syncer sync.Syncer
meshManagerParams := &mesh.NewMeshManagerParams{ meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf, Conf: *params.Conf,
@ -48,12 +52,18 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
IdGenerator: idGenerator, IdGenerator: idGenerator,
IPAllocator: ipAllocator, IPAllocator: ipAllocator,
InterfaceManipulator: interfaceManipulator, InterfaceManipulator: interfaceManipulator,
ConfigApplyer: configApplyer, ConfigApplier: configApplier,
OnDelete: params.OnDelete, OnDelete: func(mesh mesh.MeshProvider) {
_, err := syncer.Sync(mesh)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
},
} }
ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams) ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams)
configApplyer.SetMeshManager(ctrlServer.MeshManager) configApplier.SetMeshManager(ctrlServer.MeshManager)
ctrlServer.Conf = params.Conf ctrlServer.Conf = params.Conf
connManagerParams := conn.NewConnectionManagerParams{ connManagerParams := conn.NewConnectionManagerParams{
@ -83,9 +93,37 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
return nil, err return nil, err
} }
syncer = sync.NewSyncer(&sync.NewSyncerParams{
MeshManager: ctrlServer.MeshManager,
ConnectionManager: ctrlServer.ConnectionManager,
Configuration: params.Conf,
})
// Check any syncs every 1 second
syncTimer := lib.NewTimer(func() error {
err = syncer.SyncMeshes()
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
return nil
}, 1)
heartbeatTimer := lib.NewTimer(func() error {
logging.Log.WriteInfof("checking heartbeat")
return ctrlServer.MeshManager.UpdateTimeStamp()
}, params.Conf.Heartbeat)
ctrlServer.timers = append(ctrlServer.timers, syncTimer, heartbeatTimer)
ctrlServer.Querier = query.NewJmesQuerier(ctrlServer.MeshManager) ctrlServer.Querier = query.NewJmesQuerier(ctrlServer.MeshManager)
ctrlServer.ConnectionServer = connServer ctrlServer.ConnectionServer = connServer
for _, timer := range ctrlServer.timers {
go timer.Run()
}
return ctrlServer, nil return ctrlServer, nil
} }
@ -123,5 +161,13 @@ func (s *MeshCtrlServer) Close() error {
logging.Log.WriteErrorf(err.Error()) logging.Log.WriteErrorf(err.Error())
} }
for _, timer := range s.timers {
err := timer.Stop()
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
}
return nil return nil
} }

View File

@ -1,20 +1,35 @@
package ctrlserver package ctrlserver
import ( import (
"github.com/tim-beatham/wgmesh/pkg/conf" "net"
"github.com/tim-beatham/wgmesh/pkg/conn" "time"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// MeshRoute: represents a route in the mesh that is
// available to client applications
type MeshRoute struct { type MeshRoute struct {
Destination string Destination string
Path []string Path []string
} }
// Represents a WireGuard MeshNode // WireGuardStats: Represents the WireGuard configuration attached to the node
type WireGuardStats struct {
AllowedIPs []string
TransmitBytes int64
ReceivedBytes int64
PersistentKeepAliveInterval time.Duration
}
// MeshNode: represents a node in the WireGuard mesh that can be
// sent to ip chandlers
type MeshNode struct { type MeshNode struct {
HostEndpoint string HostEndpoint string
WgEndpoint string WgEndpoint string
@ -25,14 +40,16 @@ type MeshNode struct {
Description string Description string
Alias string Alias string
Services map[string]string Services map[string]string
Stats WireGuardStats
} }
// Represents a WireGuard Mesh // Mesh: Represents a WireGuard Mesh network that can be sent
// along ipc to client frameworks
type Mesh struct { type Mesh struct {
SharedKey *wgtypes.Key
Nodes map[string]MeshNode Nodes map[string]MeshNode
} }
// CtrlServer: Encapsulates th ctrlserver
type CtrlServer interface { type CtrlServer interface {
GetConfiguration() *conf.DaemonConfiguration GetConfiguration() *conf.DaemonConfiguration
GetClient() *wgctrl.Client GetClient() *wgctrl.Client
@ -42,7 +59,7 @@ type CtrlServer interface {
GetConnectionManager() conn.ConnectionManager GetConnectionManager() conn.ConnectionManager
} }
// Represents a ctrlserver to be used in WireGuard // MeshCtrlServer: Represents a ctrlserver to be used in WireGuard
type MeshCtrlServer struct { type MeshCtrlServer struct {
Client *wgctrl.Client Client *wgctrl.Client
MeshManager mesh.MeshManager MeshManager mesh.MeshManager
@ -50,4 +67,55 @@ type MeshCtrlServer struct {
ConnectionServer *conn.ConnectionServer ConnectionServer *conn.ConnectionServer
Conf *conf.DaemonConfiguration Conf *conf.DaemonConfiguration
Querier query.Querier Querier query.Querier
timers []*lib.Timer
}
// NewCtrlNode create an instance of a ctrl node to send over an
// IPC call
func NewCtrlNode(provider mesh.MeshProvider, node mesh.MeshNode) *MeshNode {
pubKey, _ := node.GetPublicKey()
ctrlNode := MeshNode{
HostEndpoint: node.GetHostEndpoint(),
WgEndpoint: node.GetWgEndpoint(),
PublicKey: pubKey.String(),
WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(),
Routes: lib.Map(node.GetRoutes(), func(r mesh.Route) MeshRoute {
return MeshRoute{
Destination: r.GetDestination().String(),
Path: r.GetPath(),
}
}),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
}
device, err := provider.GetDevice()
if err != nil {
return &ctrlNode
}
peers := lib.Filter(device.Peers, func(p wgtypes.Peer) bool {
return p.PublicKey.String() == pubKey.String()
})
if len(peers) > 0 {
peer := peers[0]
stats := WireGuardStats{
AllowedIPs: lib.Map(peer.AllowedIPs, func(i net.IPNet) string {
return i.String()
}),
TransmitBytes: peer.TransmitBytes,
ReceivedBytes: peer.ReceiveBytes,
PersistentKeepAliveInterval: peer.PersistentKeepaliveInterval,
}
ctrlNode.Stats = stats
}
return &ctrlNode
} }

View File

@ -1,10 +1,10 @@
package ctrlserver package ctrlserver
import ( import (
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/smegmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )

View File

@ -1,24 +1,22 @@
// smegdns: example of how to implement dns in the mesh
package smegdns package smegdns
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"net/rpc"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/smegmesh/pkg/query"
) )
const SockAddr = "/tmp/wgmesh_ipc.sock"
const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.` const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.`
type DNSHandler struct { type DNSHandler struct {
client *rpc.Client client *ipc.SmegmeshIpc
server *dns.Server server *dns.Server
} }
@ -27,7 +25,7 @@ type DNSHandler struct {
func (d *DNSHandler) queryMesh(meshId, alias string) net.IP { func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
var reply string var reply string
err := d.client.Call("IpcHandler.Query", &ipc.QueryMesh{ err := d.client.Query(ipc.QueryMesh{
MeshId: meshId, MeshId: meshId,
Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias), Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias),
}, &reply) }, &reply)
@ -48,6 +46,7 @@ func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
return ip return ip
} }
// handleQuery: handles a DNS query
func (d *DNSHandler) handleQuery(m *dns.Msg) { func (d *DNSHandler) handleQuery(m *dns.Msg) {
for _, q := range m.Question { for _, q := range m.Question {
switch q.Qtype { switch q.Qtype {
@ -75,6 +74,7 @@ func (d *DNSHandler) handleQuery(m *dns.Msg) {
} }
} }
// handleDNS query: handle a DNS request
func (h *DNSHandler) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { func (h *DNSHandler) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
msg := new(dns.Msg) msg := new(dns.Msg)
msg.SetReply(r) msg.SetReply(r)
@ -97,7 +97,7 @@ func (h *DNSHandler) Close() error {
} }
func NewDns(udpPort int) (*DNSHandler, error) { func NewDns(udpPort int) (*DNSHandler, error) {
client, err := rpc.DialHTTP("unix", SockAddr) client, err := ipc.NewClientIpc()
if err != nil { if err != nil {
return nil, err return nil, err

249
pkg/dot/dot.go Normal file
View File

@ -0,0 +1,249 @@
// Graph allows the definition of a DOT graph in golang
package graph
import (
"fmt"
"hash/fnv"
"strings"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
type GraphType string
type Shape string
const (
GRAPH GraphType = "graph"
DIGRAPH GraphType = "digraph"
)
const (
CIRCLE Shape = "circle"
STAR Shape = "star"
HEXAGON Shape = "hexagon"
PARALLELOGRAM Shape = "parallelogram"
)
type Graph interface {
Dottable
GetType() GraphType
}
// Cluster: represents a subgraph in the graphs
type Cluster struct {
Type GraphType
Name string
Label string
nodes map[string]*Node
edges map[string]Edge
}
// RootGraph: Represents the top level graph
type RootGraph struct {
Type GraphType
Label string
nodes map[string]*Node
clusters map[string]*Cluster
edges map[string]Edge
}
// Node: represents a graphviz not
type Node struct {
Name string
Label string
Shape Shape
Size int
}
// Edge: represents an edge between adjacent nodes
type Edge interface {
Dottable
}
// DirectEdge: contains a directed edge between any two nodes
type DirectedEdge struct {
Name string
Label string
From string
To string
}
// UndirectedEdge: contains an undirected edge between any two
// nodes
type UndirectedEdge struct {
Name string
Label string
From string
To string
}
// Dottable means an implementer can convert the struct to DOT representation
type Dottable interface {
GetDOT() (string, error)
}
// PutNode: puts a node in the root graph
func (g *RootGraph) PutNode(name, label string, size int, shape Shape) error {
_, exists := g.nodes[name]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[name] = &Node{Name: name, Label: label, Size: size, Shape: shape}
return nil
}
// PutCluster: puts a cluster in the root graph
func (g *RootGraph) PutCluster(graph *Cluster) {
g.clusters[graph.Label] = graph
}
func writeContituents[D Dottable](result *strings.Builder, elements ...D) error {
for _, node := range elements {
dot, err := node.GetDOT()
if err != nil {
return err
}
_, err = result.WriteString(dot)
if err != nil {
return err
}
}
return nil
}
// GetDOT: convert the root graph into dot format
func (g *RootGraph) GetDOT() (string, error) {
var result strings.Builder
result.WriteString(fmt.Sprintf("%s {\n", g.Type))
result.WriteString("node [colorscheme=set312];\n")
result.WriteString("layout = fdp;\n")
nodes := lib.MapValues(g.nodes)
edges := lib.MapValues(g.edges)
writeContituents(&result, nodes...)
writeContituents(&result, edges...)
for _, cluster := range g.clusters {
clusterDOT, err := cluster.GetDOT()
if err != nil {
return "", err
}
result.WriteString(clusterDOT)
}
result.WriteString("}")
return result.String(), nil
}
// GetType: get the graph type. DIRECTED|UNDIRECTED
func (r *RootGraph) GetType() GraphType {
return r.Type
}
func constructEdge(graph Graph, name, label, from, to string) Edge {
switch graph.GetType() {
case DIGRAPH:
return &DirectedEdge{Name: name, Label: label, From: from, To: to}
default:
return &UndirectedEdge{Name: name, Label: label, From: from, To: to}
}
}
// AddEdge: adds an edge between two nodes in the root graph
func (g *RootGraph) AddEdge(name string, label string, from string, to string) error {
g.edges[name] = constructEdge(g, name, label, from, to)
return nil
}
const numColours = 12
func (n *Node) hash() int {
h := fnv.New32a()
h.Write([]byte(n.Name))
return (int(h.Sum32()) % numColours) + 1
}
// GetDOT: convert the node into DOT format
func (n *Node) GetDOT() (string, error) {
return fmt.Sprintf("node[label=\"%s\",shape=%s, style=\"filled\", fillcolor=%d, width=%d, height=%d, fixedsize=true] \"%s\";\n",
n.Label, n.Shape, n.hash(), n.Size, n.Size, n.Name), nil
}
// GetDOT: Convert a directed edge into dot format
func (e *DirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("\"%s\" -> \"%s\" [label=\"%s\"];\n", e.From, e.To, e.Label), nil
}
// GetDOT: convert an undirected edge into dot format
func (e *UndirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("\"%s\" -- \"%s\" [label=\"%s\"];\n", e.From, e.To, e.Label), nil
}
// AddEdge: adds an edge between two nodes in the graph
func (g *Cluster) AddEdge(name string, label string, from string, to string) error {
g.edges[name] = constructEdge(g, name, label, from, to)
return nil
}
// PutNode: puts a node in the graph
func (g *Cluster) PutNode(name, label string, size int, shape Shape) error {
_, exists := g.nodes[name]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[name] = &Node{Name: name, Label: label, Shape: shape, Size: size}
return nil
}
// GetDOT: convert the cluster into dot format
func (g *Cluster) GetDOT() (string, error) {
var builder strings.Builder
builder.WriteString(fmt.Sprintf("subgraph \"cluster%s\" {\n", g.Label))
builder.WriteString(fmt.Sprintf("label = \"%s\"\n", g.Label))
nodes := lib.MapValues(g.nodes)
edges := lib.MapValues(g.edges)
writeContituents(&builder, nodes...)
writeContituents(&builder, edges...)
builder.WriteString("}\n")
return builder.String(), nil
}
// GetType: get the type of the subgraph (directed|undirected)
func (g *Cluster) GetType() GraphType {
return g.Type
}
// NewSubGraph: instantiate a new subgraph
func NewSubGraph(name string, label string, graphType GraphType) *Cluster {
return &Cluster{
Label: name,
Type: graphType,
Name: name,
nodes: make(map[string]*Node),
edges: make(map[string]Edge),
}
}
// NewGraph: create a new root graph
func NewGraph(label string, graphType GraphType) *RootGraph {
return &RootGraph{
Type: graphType,
Label: label,
clusters: map[string]*Cluster{},
nodes: make(map[string]*Node),
edges: make(map[string]Edge),
}
}

116
pkg/dot/wg.go Normal file
View File

@ -0,0 +1,116 @@
package graph
import (
"fmt"
"slices"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
)
// MeshGraphConverter converts a mesh to a graph
type MeshGraphConverter interface {
// convert the mesh to textual form
Generate() (string, error)
}
type MeshDOTConverter struct {
meshes map[string][]ctrlserver.MeshNode
destinations map[string]interface{}
}
func (c *MeshDOTConverter) Generate() (string, error) {
g := NewGraph("Smegmesh", GRAPH)
for meshId := range c.meshes {
err := c.generateMesh(g, meshId)
if err != nil {
return "", err
}
}
for mesh := range c.meshes {
g.PutNode(mesh, mesh, 1, CIRCLE)
}
for destination := range c.destinations {
g.PutNode(destination, destination, 1, HEXAGON)
}
return g.GetDOT()
}
func (c *MeshDOTConverter) generateMesh(g *RootGraph, meshId string) error {
nodes := c.meshes[meshId]
g.PutNode(meshId, meshId, 1, CIRCLE)
for _, node := range nodes {
c.graphNode(g, node, meshId)
}
for _, node := range nodes {
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, meshId), "", node.PublicKey, meshId)
}
return nil
}
// graphNode: graphs a node within the mesh
func (c *MeshDOTConverter) graphNode(g *RootGraph, node ctrlserver.MeshNode, meshId string) {
alias := node.Alias
if alias == "" {
alias = node.WgHost[1:len(node.WgHost)-20] + "\\n" + node.WgHost[len(node.WgHost)-20:len(node.WgHost)]
}
g.PutNode(node.PublicKey, alias, 2, CIRCLE)
for _, route := range node.Routes {
if len(route.Path) == 0 {
g.AddEdge(route.Destination, "", node.PublicKey, route.Destination)
continue
}
reversedPath := slices.Clone(route.Path)
slices.Reverse(reversedPath)
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, reversedPath[0]), "", node.PublicKey, reversedPath[0])
for _, mesh := range route.Path {
if _, ok := c.meshes[mesh]; !ok {
c.destinations[mesh] = struct{}{}
}
}
for index := range reversedPath[0 : len(reversedPath)-1] {
routeID := fmt.Sprintf("%s to %s", reversedPath[index], reversedPath[index+1])
g.AddEdge(routeID, "", reversedPath[index], reversedPath[index+1])
}
if route.Destination == "::/0" {
c.destinations[route.Destination] = struct{}{}
lastMesh := reversedPath[len(reversedPath)-1]
routeID := fmt.Sprintf("%s to %s", lastMesh, route.Destination)
g.AddEdge(routeID, "", lastMesh, route.Destination)
}
}
for service := range node.Services {
c.putService(g, service, meshId, node)
}
}
// putService: construct a service node and a link between the nodes
func (c *MeshDOTConverter) putService(g *RootGraph, key, meshId string, node ctrlserver.MeshNode) {
serviceID := fmt.Sprintf("%s%s%s", key, node.PublicKey, meshId)
g.PutNode(serviceID, key, 1, PARALLELOGRAM)
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, serviceID), "", node.PublicKey, serviceID)
}
func NewMeshGraphConverter(meshes map[string][]ctrlserver.MeshNode) MeshGraphConverter {
return &MeshDOTConverter{
meshes: meshes,
destinations: make(map[string]interface{}),
}
}

View File

@ -1,178 +0,0 @@
// Graph allows the definition of a DOT graph in golang
package graph
import (
"errors"
"fmt"
"hash/fnv"
"strings"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type GraphType string
type Shape string
const (
GRAPH GraphType = "graph"
DIGRAPH = "digraph"
)
const (
CIRCLE Shape = "circle"
STAR Shape = "star"
HEXAGON Shape = "hexagon"
)
type Graph struct {
Type GraphType
Label string
nodes map[string]*Node
edges []Edge
}
type Node struct {
Name string
Shape Shape
}
type Edge interface {
Dottable
}
type DirectedEdge struct {
Label string
From *Node
To *Node
}
type UndirectedEdge struct {
Label string
From *Node
To *Node
}
// Dottable means an implementer can convert the struct to DOT representation
type Dottable interface {
GetDOT() (string, error)
}
func NewGraph(label string, graphType GraphType) *Graph {
return &Graph{Type: graphType, Label: label, nodes: make(map[string]*Node), edges: make([]Edge, 0)}
}
// PutNode: puts a node in the graph
func (g *Graph) PutNode(label string, shape Shape) error {
_, exists := g.nodes[label]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[label] = &Node{Name: label, Shape: shape}
return nil
}
func writeContituents[D Dottable](result *strings.Builder, elements ...D) error {
for _, node := range elements {
dot, err := node.GetDOT()
if err != nil {
return err
}
_, err = result.WriteString(dot)
if err != nil {
return err
}
}
return nil
}
func (g *Graph) GetDOT() (string, error) {
var result strings.Builder
_, err := result.WriteString(fmt.Sprintf("%s {\n", g.Type))
if err != nil {
return "", err
}
_, err = result.WriteString("node [colorscheme=set312];\n")
if err != nil {
return "", err
}
nodes := lib.MapValues(g.nodes)
err = writeContituents(&result, nodes...)
if err != nil {
return "", err
}
err = writeContituents(&result, g.edges...)
if err != nil {
return "", err
}
_, err = result.WriteString("}")
if err != nil {
return "", err
}
return result.String(), nil
}
func (g *Graph) constructEdge(label string, from *Node, to *Node) Edge {
switch g.Type {
case DIGRAPH:
return &DirectedEdge{Label: label, From: from, To: to}
default:
return &UndirectedEdge{Label: label, From: from, To: to}
}
}
// AddEdge: adds an edge between two nodes in the graph
func (g *Graph) AddEdge(label string, from string, to string) error {
fromNode, exists := g.nodes[from]
if !exists {
return errors.New(fmt.Sprintf("Node %s does not exist", from))
}
toNode, exists := g.nodes[to]
if !exists {
return errors.New(fmt.Sprintf("Node %s does not exist", to))
}
g.edges = append(g.edges, g.constructEdge(label, fromNode, toNode))
return nil
}
const numColours = 12
func (n *Node) hash() int {
h := fnv.New32a()
h.Write([]byte(n.Name))
return (int(h.Sum32()) % numColours) + 1
}
func (n *Node) GetDOT() (string, error) {
return fmt.Sprintf("node[shape=%s, style=\"filled\", fillcolor=%d] %s;\n",
n.Shape, n.hash(), n.Name), nil
}
func (e *DirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("%s -> %s;\n", e.From.Name, e.To.Name), nil
}
func (e *UndirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("%s -- %s;\n", e.From.Name, e.To.Name), nil
}

212
pkg/grpc/ctrlserver.pb.go Normal file
View File

@ -0,0 +1,212 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: pkg/grpc/ctrlserver.proto
package rpc
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type GetMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"`
}
func (x *GetMeshRequest) Reset() {
*x = GetMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GetMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetMeshRequest) ProtoMessage() {}
func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead.
func (*GetMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *GetMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
type GetMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Mesh []byte `protobuf:"bytes,1,opt,name=mesh,proto3" json:"mesh,omitempty"`
}
func (x *GetMeshReply) Reset() {
*x = GetMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GetMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetMeshReply) ProtoMessage() {}
func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead.
func (*GetMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_proto_rawDescGZIP(), []int{1}
}
func (x *GetMeshReply) GetMesh() []byte {
if x != nil {
return x.Mesh
}
return nil
}
var File_pkg_grpc_ctrlserver_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_proto_rawDesc = []byte{
0x0a, 0x19, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x22, 0x28, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49,
0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22,
0x22, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12,
0x12, 0x0a, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d,
0x65, 0x73, 0x68, 0x32, 0x4f, 0x0a, 0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68,
0x12, 0x18, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70,
0x6c, 0x79, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_pkg_grpc_ctrlserver_proto_rawDescOnce sync.Once
file_pkg_grpc_ctrlserver_proto_rawDescData = file_pkg_grpc_ctrlserver_proto_rawDesc
)
func file_pkg_grpc_ctrlserver_proto_rawDescGZIP() []byte {
file_pkg_grpc_ctrlserver_proto_rawDescOnce.Do(func() {
file_pkg_grpc_ctrlserver_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_ctrlserver_proto_rawDescData)
})
return file_pkg_grpc_ctrlserver_proto_rawDescData
}
var file_pkg_grpc_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_proto_goTypes = []interface{}{
(*GetMeshRequest)(nil), // 0: rpctypes.GetMeshRequest
(*GetMeshReply)(nil), // 1: rpctypes.GetMeshReply
}
var file_pkg_grpc_ctrlserver_proto_depIdxs = []int32{
0, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
1, // 1: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_pkg_grpc_ctrlserver_proto_init() }
func file_pkg_grpc_ctrlserver_proto_init() {
if File_pkg_grpc_ctrlserver_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_pkg_grpc_ctrlserver_proto_goTypes,
DependencyIndexes: file_pkg_grpc_ctrlserver_proto_depIdxs,
MessageInfos: file_pkg_grpc_ctrlserver_proto_msgTypes,
}.Build()
File_pkg_grpc_ctrlserver_proto = out.File
file_pkg_grpc_ctrlserver_proto_rawDesc = nil
file_pkg_grpc_ctrlserver_proto_goTypes = nil
file_pkg_grpc_ctrlserver_proto_depIdxs = nil
}

View File

@ -1,18 +0,0 @@
syntax = "proto3";
package rpctypes;
option go_package = "pkg/rpc";
service Authentication {
rpc JoinMesh(JoinAuthMeshRequest) returns (JoinAuthMeshReply) {}
}
message JoinAuthMeshRequest {
string meshId = 1;
string alias = 2;
}
message JoinAuthMeshReply {
bool success = 1;
optional string token = 2;
}

View File

@ -0,0 +1,105 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.12
// source: pkg/grpc/ctrlserver.proto
package rpc
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// MeshCtrlServerClient is the client API for MeshCtrlServer service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type MeshCtrlServerClient interface {
GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error)
}
type meshCtrlServerClient struct {
cc grpc.ClientConnInterface
}
func NewMeshCtrlServerClient(cc grpc.ClientConnInterface) MeshCtrlServerClient {
return &meshCtrlServerClient{cc}
}
func (c *meshCtrlServerClient) GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error) {
out := new(GetMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.MeshCtrlServer/GetMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// MeshCtrlServerServer is the server API for MeshCtrlServer service.
// All implementations must embed UnimplementedMeshCtrlServerServer
// for forward compatibility
type MeshCtrlServerServer interface {
GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error)
mustEmbedUnimplementedMeshCtrlServerServer()
}
// UnimplementedMeshCtrlServerServer must be embedded to have forward compatible implementations.
type UnimplementedMeshCtrlServerServer struct {
}
func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {}
// UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to MeshCtrlServerServer will
// result in compilation errors.
type UnsafeMeshCtrlServerServer interface {
mustEmbedUnimplementedMeshCtrlServerServer()
}
func RegisterMeshCtrlServerServer(s grpc.ServiceRegistrar, srv MeshCtrlServerServer) {
s.RegisterService(&MeshCtrlServer_ServiceDesc, srv)
}
func _MeshCtrlServer_GetMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(MeshCtrlServerServer).GetMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.MeshCtrlServer/GetMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(MeshCtrlServerServer).GetMesh(ctx, req.(*GetMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var MeshCtrlServer_ServiceDesc = grpc.ServiceDesc{
ServiceName: "rpctypes.MeshCtrlServer",
HandlerType: (*MeshCtrlServerServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "GetMesh",
Handler: _MeshCtrlServer_GetMesh_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver.proto",
}

233
pkg/grpc/syncservice.pb.go Normal file
View File

@ -0,0 +1,233 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: pkg/grpc/syncservice.proto
package rpc
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type SyncMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"`
Changes []byte `protobuf:"bytes,2,opt,name=changes,proto3" json:"changes,omitempty"`
}
func (x *SyncMeshRequest) Reset() {
*x = SyncMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SyncMeshRequest) ProtoMessage() {}
func (x *SyncMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SyncMeshRequest.ProtoReflect.Descriptor instead.
func (*SyncMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_syncservice_proto_rawDescGZIP(), []int{0}
}
func (x *SyncMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
func (x *SyncMeshRequest) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
type SyncMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
Changes []byte `protobuf:"bytes,2,opt,name=changes,proto3" json:"changes,omitempty"`
}
func (x *SyncMeshReply) Reset() {
*x = SyncMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SyncMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SyncMeshReply) ProtoMessage() {}
func (x *SyncMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_syncservice_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SyncMeshReply.ProtoReflect.Descriptor instead.
func (*SyncMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_syncservice_proto_rawDescGZIP(), []int{1}
}
func (x *SyncMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
func (x *SyncMeshReply) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
var File_pkg_grpc_syncservice_proto protoreflect.FileDescriptor
var file_pkg_grpc_syncservice_proto_rawDesc = []byte{
0x0a, 0x1a, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x73, 0x79, 0x6e, 0x63, 0x73,
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0b, 0x73, 0x79,
0x6e, 0x63, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x22, 0x43, 0x0a, 0x0f, 0x53, 0x79, 0x6e,
0x63, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06,
0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65,
0x73, 0x68, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18,
0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x22, 0x43,
0x0a, 0x0d, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12,
0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08,
0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x68, 0x61,
0x6e, 0x67, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x63, 0x68, 0x61, 0x6e,
0x67, 0x65, 0x73, 0x32, 0x59, 0x0a, 0x0b, 0x53, 0x79, 0x6e, 0x63, 0x53, 0x65, 0x72, 0x76, 0x69,
0x63, 0x65, 0x12, 0x4a, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x1c,
0x2e, 0x73, 0x79, 0x6e, 0x63, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x79, 0x6e,
0x63, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x73,
0x79, 0x6e, 0x63, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x79, 0x6e, 0x63, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x09,
0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x33,
}
var (
file_pkg_grpc_syncservice_proto_rawDescOnce sync.Once
file_pkg_grpc_syncservice_proto_rawDescData = file_pkg_grpc_syncservice_proto_rawDesc
)
func file_pkg_grpc_syncservice_proto_rawDescGZIP() []byte {
file_pkg_grpc_syncservice_proto_rawDescOnce.Do(func() {
file_pkg_grpc_syncservice_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_syncservice_proto_rawDescData)
})
return file_pkg_grpc_syncservice_proto_rawDescData
}
var file_pkg_grpc_syncservice_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_syncservice_proto_goTypes = []interface{}{
(*SyncMeshRequest)(nil), // 0: syncservice.SyncMeshRequest
(*SyncMeshReply)(nil), // 1: syncservice.SyncMeshReply
}
var file_pkg_grpc_syncservice_proto_depIdxs = []int32{
0, // 0: syncservice.SyncService.SyncMesh:input_type -> syncservice.SyncMeshRequest
1, // 1: syncservice.SyncService.SyncMesh:output_type -> syncservice.SyncMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_pkg_grpc_syncservice_proto_init() }
func file_pkg_grpc_syncservice_proto_init() {
if File_pkg_grpc_syncservice_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_syncservice_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_syncservice_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SyncMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_syncservice_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_pkg_grpc_syncservice_proto_goTypes,
DependencyIndexes: file_pkg_grpc_syncservice_proto_depIdxs,
MessageInfos: file_pkg_grpc_syncservice_proto_msgTypes,
}.Build()
File_pkg_grpc_syncservice_proto = out.File
file_pkg_grpc_syncservice_proto_rawDesc = nil
file_pkg_grpc_syncservice_proto_goTypes = nil
file_pkg_grpc_syncservice_proto_depIdxs = nil
}

View File

@ -4,18 +4,9 @@ package syncservice;
option go_package = "pkg/rpc"; option go_package = "pkg/rpc";
service SyncService { service SyncService {
rpc GetConf(GetConfRequest) returns (GetConfReply) {}
rpc SyncMesh(stream SyncMeshRequest) returns (stream SyncMeshReply) {} rpc SyncMesh(stream SyncMeshRequest) returns (stream SyncMeshReply) {}
} }
message GetConfRequest {
string meshId = 1;
}
message GetConfReply {
bytes mesh = 1;
}
message SyncMeshRequest { message SyncMeshRequest {
string meshId = 1; string meshId = 1;
bytes changes = 2; bytes changes = 2;

View File

@ -0,0 +1,137 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.12
// source: pkg/grpc/syncservice.proto
package rpc
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// SyncServiceClient is the client API for SyncService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type SyncServiceClient interface {
SyncMesh(ctx context.Context, opts ...grpc.CallOption) (SyncService_SyncMeshClient, error)
}
type syncServiceClient struct {
cc grpc.ClientConnInterface
}
func NewSyncServiceClient(cc grpc.ClientConnInterface) SyncServiceClient {
return &syncServiceClient{cc}
}
func (c *syncServiceClient) SyncMesh(ctx context.Context, opts ...grpc.CallOption) (SyncService_SyncMeshClient, error) {
stream, err := c.cc.NewStream(ctx, &SyncService_ServiceDesc.Streams[0], "/syncservice.SyncService/SyncMesh", opts...)
if err != nil {
return nil, err
}
x := &syncServiceSyncMeshClient{stream}
return x, nil
}
type SyncService_SyncMeshClient interface {
Send(*SyncMeshRequest) error
Recv() (*SyncMeshReply, error)
grpc.ClientStream
}
type syncServiceSyncMeshClient struct {
grpc.ClientStream
}
func (x *syncServiceSyncMeshClient) Send(m *SyncMeshRequest) error {
return x.ClientStream.SendMsg(m)
}
func (x *syncServiceSyncMeshClient) Recv() (*SyncMeshReply, error) {
m := new(SyncMeshReply)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// SyncServiceServer is the server API for SyncService service.
// All implementations must embed UnimplementedSyncServiceServer
// for forward compatibility
type SyncServiceServer interface {
SyncMesh(SyncService_SyncMeshServer) error
mustEmbedUnimplementedSyncServiceServer()
}
// UnimplementedSyncServiceServer must be embedded to have forward compatible implementations.
type UnimplementedSyncServiceServer struct {
}
func (UnimplementedSyncServiceServer) SyncMesh(SyncService_SyncMeshServer) error {
return status.Errorf(codes.Unimplemented, "method SyncMesh not implemented")
}
func (UnimplementedSyncServiceServer) mustEmbedUnimplementedSyncServiceServer() {}
// UnsafeSyncServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to SyncServiceServer will
// result in compilation errors.
type UnsafeSyncServiceServer interface {
mustEmbedUnimplementedSyncServiceServer()
}
func RegisterSyncServiceServer(s grpc.ServiceRegistrar, srv SyncServiceServer) {
s.RegisterService(&SyncService_ServiceDesc, srv)
}
func _SyncService_SyncMesh_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(SyncServiceServer).SyncMesh(&syncServiceSyncMeshServer{stream})
}
type SyncService_SyncMeshServer interface {
Send(*SyncMeshReply) error
Recv() (*SyncMeshRequest, error)
grpc.ServerStream
}
type syncServiceSyncMeshServer struct {
grpc.ServerStream
}
func (x *syncServiceSyncMeshServer) Send(m *SyncMeshReply) error {
return x.ServerStream.SendMsg(m)
}
func (x *syncServiceSyncMeshServer) Recv() (*SyncMeshRequest, error) {
m := new(SyncMeshRequest)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// SyncService_ServiceDesc is the grpc.ServiceDesc for SyncService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var SyncService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "syncservice.SyncService",
HandlerType: (*SyncServiceServer)(nil),
Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{
{
StreamName: "SyncMesh",
Handler: _SyncService_SyncMesh_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "pkg/grpc/syncservice.proto",
}

View File

@ -1,132 +0,0 @@
// hosts: utility for modifying the /etc/hosts file
package hosts
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"os"
"strings"
)
// HOSTS_FILE is the hosts file location
const HOSTS_FILE = "/etc/hosts"
const DOMAIN_HEADER = "#WG AUTO GENERATED HOSTS"
const DOMAIN_TRAILER = "#WG AUTO GENERATED HOSTS END"
type HostsEntry struct {
Alias string
Ip net.IP
}
// Generic interface to manipulate /etc/hosts file
type HostsManipulator interface {
// AddrAddr associates an aliasd with a given IP address
AddAddr(hosts ...HostsEntry)
// Remove deletes the entry from /etc/hosts
Remove(hosts ...HostsEntry)
// Writes the changes to /etc/hosts file
Write() error
}
type HostsManipulatorImpl struct {
hosts map[string]HostsEntry
}
// AddAddr implements HostsManipulator.
func (m *HostsManipulatorImpl) AddAddr(hosts ...HostsEntry) {
changed := false
for _, host := range hosts {
prev, ok := m.hosts[host.Ip.String()]
if !ok || prev.Alias != host.Alias {
changed = true
}
m.hosts[host.Ip.String()] = host
}
if changed {
m.Write()
}
}
// Remove implements HostsManipulator.
func (m *HostsManipulatorImpl) Remove(hosts ...HostsEntry) {
lenBefore := len(m.hosts)
for _, host := range hosts {
delete(m.hosts, host.Alias)
}
if lenBefore != len(m.hosts) {
m.Write()
}
}
func (m *HostsManipulatorImpl) removeHosts() string {
hostsFile, err := os.ReadFile(HOSTS_FILE)
if err != nil {
return ""
}
var contents strings.Builder
scanner := bufio.NewScanner(bytes.NewReader(hostsFile))
hostsSection := false
for scanner.Scan() {
line := scanner.Text()
if err == io.EOF {
break
} else if err != nil {
return ""
}
if !hostsSection && strings.Contains(line, DOMAIN_HEADER) {
hostsSection = true
}
if !hostsSection {
contents.WriteString(line + "\n")
}
if hostsSection && strings.Contains(line, DOMAIN_TRAILER) {
hostsSection = false
}
}
if scanner.Err() != nil && scanner.Err() != io.EOF {
return ""
}
return contents.String()
}
// Write implements HostsManipulator
func (m *HostsManipulatorImpl) Write() error {
contents := m.removeHosts()
var nextHosts strings.Builder
nextHosts.WriteString(contents)
nextHosts.WriteString(DOMAIN_HEADER + "\n")
for _, host := range m.hosts {
nextHosts.WriteString(fmt.Sprintf("%s\t%s\n", host.Ip.String(), host.Alias))
}
nextHosts.WriteString(DOMAIN_TRAILER + "\n")
return os.WriteFile(HOSTS_FILE, []byte(nextHosts.String()), 0644)
}
func NewHostsManipulator() HostsManipulator {
return &HostsManipulatorImpl{hosts: make(map[string]HostsEntry)}
}

View File

@ -1,8 +1,7 @@
package ip package ip
/* // Generates a CGA see RFC 3972
* Use a WireGuard public key to generate a unique interface ID // https://datatracker.ietf.org/doc/html/rfc3972
*/
import ( import (
"crypto/rand" "crypto/rand"
@ -22,19 +21,23 @@ const (
InterfaceIdLen = 8 InterfaceIdLen = 8
) )
/* // CGAParameters: parameters used to create a new cryotpgraphically generated
* Cga parameters used to generate an IPV6 interface ID // address
*/
type CgaParameters struct { type CgaParameters struct {
Modifier [ModifierLength]byte Modifier [ModifierLength]byte
// SubnetPrefix: prefix of the subnetwork
SubnetPrefix [2 * InterfaceIdLen]byte SubnetPrefix [2 * InterfaceIdLen]byte
// CollisionCount: total number of times we have atempted to generate a porefix
CollisionCount uint8 CollisionCount uint8
// PublicKey: WireGuard public key of our interface
PublicKey wgtypes.Key PublicKey wgtypes.Key
// interfaceId: the generated interfaceId
interfaceId [2 * InterfaceIdLen]byte interfaceId [2 * InterfaceIdLen]byte
// flag: represents whether or not an IP address has been generated
flag byte flag byte
} }
func NewCga(key wgtypes.Key, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParameters, error) { func NewCga(key wgtypes.Key, collisionCount uint8, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParameters, error) {
var params CgaParameters var params CgaParameters
_, err := rand.Read(params.Modifier[:]) _, err := rand.Read(params.Modifier[:])
@ -45,25 +48,10 @@ func NewCga(key wgtypes.Key, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParamet
params.PublicKey = key params.PublicKey = key
params.SubnetPrefix = subnetPrefix params.SubnetPrefix = subnetPrefix
params.CollisionCount = collisionCount
return &params, nil return &params, nil
} }
func (c *CgaParameters) generateHash2() []byte {
var byteVal [hash2Length]byte
for i := 0; i < ModifierLength; i++ {
byteVal[i] = c.Modifier[i]
}
for i := 0; i < wgtypes.KeyLen; i++ {
byteVal[ModifierLength+ZeroLength+i] = c.PublicKey[i]
}
hash := sha1.Sum(byteVal[:])
return hash[:Hash2Prefix]
}
func (c *CgaParameters) generateHash1() []byte { func (c *CgaParameters) generateHash1() []byte {
var byteVal [hash1Length]byte var byteVal [hash1Length]byte
@ -78,7 +66,6 @@ func (c *CgaParameters) generateHash1() []byte {
byteVal[hash1Length-1] = c.CollisionCount byteVal[hash1Length-1] = c.CollisionCount
hash := sha1.Sum(byteVal[:]) hash := sha1.Sum(byteVal[:])
return hash[:Hash1Prefix] return hash[:Hash1Prefix]
} }
@ -90,9 +77,6 @@ func clearBit(num, pos int) byte {
} }
func (c *CgaParameters) generateInterface() []byte { func (c *CgaParameters) generateInterface() []byte {
// TODO: On duplicate address detection increment collision.
// Also incorporate SEC
hash1 := c.generateHash1() hash1 := c.generateHash1()
var interfaceId []byte = make([]byte, InterfaceIdLen) var interfaceId []byte = make([]byte, InterfaceIdLen)
@ -101,7 +85,6 @@ func (c *CgaParameters) generateInterface() []byte {
interfaceId[0] = clearBit(int(interfaceId[0]), 6) interfaceId[0] = clearBit(int(interfaceId[0]), 6)
interfaceId[0] = clearBit(int(interfaceId[1]), 7) interfaceId[0] = clearBit(int(interfaceId[1]), 7)
return interfaceId return interfaceId
} }

View File

@ -6,6 +6,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// IPAllocator: abstracts the process of creating an IP address
type IPAllocator interface { type IPAllocator interface {
GetIP(key wgtypes.Key, meshId string) (net.IP, error) GetIP(key wgtypes.Key, meshId string, collisionCount uint8) (net.IP, error)
} }

View File

@ -8,6 +8,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// ULABuilder: Create a new ULA in WireGuard
type ULABuilder struct{} type ULABuilder struct{}
func getMeshPrefix(meshId string) [16]byte { func getMeshPrefix(meshId string) [16]byte {
@ -39,10 +40,10 @@ func (u *ULABuilder) GetIPNet(meshId string) (*net.IPNet, error) {
return net, nil return net, nil
} }
func (u *ULABuilder) GetIP(key wgtypes.Key, meshId string) (net.IP, error) { func (u *ULABuilder) GetIP(key wgtypes.Key, meshId string, collisionCount uint8) (net.IP, error) {
ulaPrefix := getMeshPrefix(meshId) ulaPrefix := getMeshPrefix(meshId)
c, err := NewCga(key, ulaPrefix) c, err := NewCga(key, collisionCount, ulaPrefix)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -5,74 +5,191 @@ import (
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
ipcRPC "net/rpc"
"os" "os"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
) )
type NewMeshArgs struct { const SockAddr = "/tmp/smeg.sock"
type MeshIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error
JoinMesh(args *JoinMeshArgs, reply *string) error
LeaveMesh(meshId string, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error
Query(query QueryMesh, reply *string) error
PutDescription(args PutDescriptionArgs, reply *string) error
PutAlias(args PutAliasArgs, reply *string) error
PutService(args PutServiceArgs, reply *string) error
DeleteService(args DeleteServiceArgs, reply *string) error
}
// WireGuardArgs are provided args specific to WireGuard
type WireGuardArgs struct {
// WgPort is the WireGuard port to expose // WgPort is the WireGuard port to expose
WgPort int WgPort int
// KeepAliveWg is the number of seconds to keep alive
// for WireGuard NAT/firewall traversal
KeepAliveWg int
// AdvertiseRoutes whether or not to advertise routes to and from the
// mesh network
AdvertiseRoutes bool
// AdvertiseDefaultRoute whether or not to advertise the default route
// into the mesh network
AdvertiseDefaultRoute bool
// Endpoint is the routable alias of the machine. Can be an IP // Endpoint is the routable alias of the machine. Can be an IP
// or DNS entry // or DNS entry
Endpoint string Endpoint string
// Role is the role of the individual in the mesh
Role string Role string
} }
type NewMeshArgs struct {
// WgArgs are specific WireGuard args to use
WgArgs WireGuardArgs
}
type JoinMeshArgs struct { type JoinMeshArgs struct {
// MeshId is the ID of the mesh to join // MeshId is the ID of the mesh to join
MeshId string MeshId string
// IpAddress is a routable IP in another mesh // IpAddress is a routable IP in another mesh
IpAdress string IpAddress string
// Port is the WireGuard port to expose // WgArgs is the WireGuard parameters to use.
Port int WgArgs WireGuardArgs
// Endpoint to use to override the default
Endpoint string
// Client specifies whether we should join as a client of the peer
// we are connecting to
Client bool
Role string
} }
// PutServiceArgs: args to place a service into the data store
type PutServiceArgs struct { type PutServiceArgs struct {
Service string Service string
Value string Value string
MeshId string
} }
// DeleteServiceArgs: args to remove a service from the data store
type DeleteServiceArgs struct {
Service string
MeshId string
}
// PutAliasArgs: args to assign an alias to a node
type PutAliasArgs struct {
// Alias: represents the alias of the node
Alias string
// MeshId: represents the meshID of the node
MeshId string
}
// PutDescriptionArgs: args to assign a description to a node
type PutDescriptionArgs struct {
// Description: descriptio to add to the network
Description string
// MeshID to add to the mesh network
MeshId string
}
// GetMeshReply: ipc reply to get the mesh network
type GetMeshReply struct { type GetMeshReply struct {
Nodes []ctrlserver.MeshNode Nodes []ctrlserver.MeshNode
} }
// ListMeshReply: ipc reply of the networks the node is part of
type ListMeshReply struct { type ListMeshReply struct {
Meshes []string Meshes []string
} }
// Querymesh: ipc args to query a mesh network
type QueryMesh struct { type QueryMesh struct {
// MeshId: id of the mesh to query
MeshId string MeshId string
// JMESPath: query string to query
Query string Query string
} }
type GetNodeArgs struct { // ClientIpc: Framework to invoke ipc calls to the daemon
NodeId string type ClientIpc interface {
MeshId string // CreateMesh: create a mesh network, return an error if the operation failed
}
type MeshIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error // ListMesh: list mesh network the node is a part of, return an error if the operation failed
ListMeshes(args *ListMeshReply, reply *string) error
// JoinMesh: join a mesh network return an error if the operation failed
JoinMesh(args JoinMeshArgs, reply *string) error JoinMesh(args JoinMeshArgs, reply *string) error
// LeaveMesh: leave a mesh network, return an error if the operation failed
LeaveMesh(meshId string, reply *string) error LeaveMesh(meshId string, reply *string) error
// GetMesh: get the given mesh network, return an error if the operation failed
GetMesh(meshId string, reply *GetMeshReply) error GetMesh(meshId string, reply *GetMeshReply) error
GetDOT(meshId string, reply *string) error // Query: query the given mesh network
Query(query QueryMesh, reply *string) error Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error // PutDescription: assign a description to yourself
PutAlias(alias string, reply *string) error PutDescription(args PutDescriptionArgs, reply *string) error
// PutAlias: assign an alias to yourself
PutAlias(args PutAliasArgs, reply *string) error
// PutService: assign a service to yourself
PutService(args PutServiceArgs, reply *string) error PutService(args PutServiceArgs, reply *string) error
GetNode(args GetNodeArgs, reply *string) error // DeleteService: retract a service
DeleteService(service string, reply *string) error DeleteService(args DeleteServiceArgs, reply *string) error
} }
const SockAddr = "/tmp/wgmesh_ipc.sock" type SmegmeshIpc struct {
client *ipcRPC.Client
}
func NewClientIpc() (*SmegmeshIpc, error) {
client, err := ipcRPC.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
return &SmegmeshIpc{
client: client,
}, nil
}
func (c *SmegmeshIpc) CreateMesh(args *NewMeshArgs, reply *string) error {
return c.client.Call("IpcHandler.CreateMesh", args, reply)
}
func (c *SmegmeshIpc) ListMeshes(reply *ListMeshReply) error {
return c.client.Call("IpcHandler.ListMeshes", "", reply)
}
func (c *SmegmeshIpc) JoinMesh(args JoinMeshArgs, reply *string) error {
return c.client.Call("IpcHandler.JoinMesh", &args, reply)
}
func (c *SmegmeshIpc) LeaveMesh(meshId string, reply *string) error {
return c.client.Call("IpcHandler.LeaveMesh", &meshId, reply)
}
func (c *SmegmeshIpc) GetMesh(meshId string, reply *GetMeshReply) error {
return c.client.Call("IpcHandler.GetMesh", &meshId, reply)
}
func (c *SmegmeshIpc) Query(query QueryMesh, reply *string) error {
return c.client.Call("IpcHandler.Query", &query, reply)
}
func (c *SmegmeshIpc) PutDescription(args PutDescriptionArgs, reply *string) error {
return c.client.Call("IpcHandler.PutDescription", &args, reply)
}
func (c *SmegmeshIpc) PutAlias(args PutAliasArgs, reply *string) error {
return c.client.Call("IpcHandler.PutAlias", &args, reply)
}
func (c *SmegmeshIpc) PutService(args PutServiceArgs, reply *string) error {
return c.client.Call("IpcHandler.PutService", &args, reply)
}
func (c *SmegmeshIpc) DeleteService(args DeleteServiceArgs, reply *string) error {
return c.client.Call("IpcHandler.DeleteService", &args, reply)
}
func (c *SmegmeshIpc) Close() error {
return c.client.Close()
}
func RunIpcHandler(server MeshIpc) error { func RunIpcHandler(server MeshIpc) error {
if err := os.RemoveAll(SockAddr); err != nil { if err := os.RemoveAll(SockAddr); err != nil {

View File

@ -17,7 +17,7 @@ func HashString(value string) int {
} }
// ConsistentHash implementation. Traverse the values until we find a key // ConsistentHash implementation. Traverse the values until we find a key
// less than ours. // greater than ours.
func ConsistentHash[V any, K any](values []V, client K, bucketFunc func(V) int, keyFunc func(K) int) V { func ConsistentHash[V any, K any](values []V, client K, bucketFunc func(V) int, keyFunc func(K) int) V {
if len(values) == 0 { if len(values) == 0 {
panic("values is empty") panic("values is empty")
@ -36,11 +36,13 @@ func ConsistentHash[V any, K any](values []V, client K, bucketFunc func(V) int,
ourKey := keyFunc(client) ourKey := keyFunc(client)
for _, record := range vs { idx := sort.Search(len(vs), func(i int) bool {
if ourKey < record.value { return vs[i].value >= ourKey
return record.record })
}
if idx == len(vs) {
return vs[0].record
} }
return vs[0].record return vs[idx].record
} }

View File

@ -3,6 +3,7 @@ package lib
import ( import (
"github.com/anandvarma/namegen" "github.com/anandvarma/namegen"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lithammer/shortuuid"
) )
// IdGenerator generates unique ids // IdGenerator generates unique ids
@ -19,6 +20,14 @@ func (g *UUIDGenerator) GetId() (string, error) {
return id.String(), nil return id.String(), nil
} }
type ShortIDGenerator struct {
}
func (g *ShortIDGenerator) GetId() (string, error) {
id := shortuuid.New()
return id, nil
}
type IDNameGenerator struct { type IDNameGenerator struct {
} }

View File

@ -6,27 +6,21 @@ import (
"net" "net"
"github.com/jsimonetti/rtnetlink" "github.com/jsimonetti/rtnetlink"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// Maximum MTU to assin to WireGuard
// This isn't configurable
const WIREGUARD_MTU = 1420
// RtNetlinkConfig: represents an rtnetlkink configuration instance
type RtNetlinkConfig struct { type RtNetlinkConfig struct {
// conn: connection to the rtnetlink API
conn *rtnetlink.Conn conn *rtnetlink.Conn
} }
func NewRtNetlinkConfig() (*RtNetlinkConfig, error) { // CreateLink: Create a netlink interface if it does not exist. ifName is the name of the netlink interface
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, err
}
return &RtNetlinkConfig{conn: conn}, nil
}
const WIREGUARD_MTU = 1420
// Create a netlink interface if it does not exist. ifName is the name of the netlink interface
func (c *RtNetlinkConfig) CreateLink(ifName string) error { func (c *RtNetlinkConfig) CreateLink(ifName string) error {
_, err := net.InterfaceByName(ifName) _, err := net.InterfaceByName(ifName)
@ -51,7 +45,7 @@ func (c *RtNetlinkConfig) CreateLink(ifName string) error {
return nil return nil
} }
// Delete link delete the specified interface // DeleteLink: delete the specified interface
func (c *RtNetlinkConfig) DeleteLink(ifName string) error { func (c *RtNetlinkConfig) DeleteLink(ifName string) error {
iface, err := net.InterfaceByName(ifName) iface, err := net.InterfaceByName(ifName)
@ -68,7 +62,7 @@ func (c *RtNetlinkConfig) DeleteLink(ifName string) error {
return nil return nil
} }
// AddAddress adds an address to the given interface. // AddAddress: adds an address to the given interface.
func (c *RtNetlinkConfig) AddAddress(ifName string, address string) error { func (c *RtNetlinkConfig) AddAddress(ifName string, address string) error {
iface, err := net.InterfaceByName(ifName) iface, err := net.InterfaceByName(ifName)
@ -177,7 +171,7 @@ func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error {
return nil return nil
} }
// DeleteRoute deletes routes with the gateway and destination // DeleteRoute: deletes routes with the gateway and destination
func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error { func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
iface, err := net.InterfaceByName(ifName) iface, err := net.InterfaceByName(ifName)
@ -219,17 +213,21 @@ func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
return nil return nil
} }
// route: represents a rout to add to the RIB
type Route struct { type Route struct {
Gateway net.IP Gateway net.IP
Destination net.IPNet Destination net.IPNet
} }
func (r1 Route) equal(r2 Route) bool { func (r1 Route) equal(r2 Route) bool {
mask1Ones, _ := r1.Destination.Mask.Size()
mask2Ones, _ := r2.Destination.Mask.Size()
return r1.Gateway.String() == r2.Gateway.String() && return r1.Gateway.String() == r2.Gateway.String() &&
r1.Destination.String() == r2.Destination.String() (mask1Ones == 0 && mask2Ones == 0 || r1.Destination.IP.Equal(r2.Destination.IP))
} }
// DeleteRoutes deletes all routes not in exclude // DeleteRoutes: deletes all routes not in exclude on the given interface
func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error { func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error {
routes, err := c.listRoutes(ifName, family) routes, err := c.listRoutes(ifName, family)
@ -257,18 +255,11 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
shouldExclude := func(r Route) bool { shouldExclude := func(r Route) bool {
for _, route := range exclude { for _, route := range exclude {
if route.equal(r) { if r.equal(route) {
return false return false
} }
}
if family == unix.AF_INET && route.Destination.IP.To4() == nil {
return false
}
if family == unix.AF_INET6 && route.Destination.IP.To16() == nil {
return false
}
}
return true return true
} }
@ -286,7 +277,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
return nil return nil
} }
// listRoutes lists all routes on the interface // listRoutes: lists all routes on the interface
func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.RouteMessage, error) { func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.RouteMessage, error) {
iface, err := net.InterfaceByName(ifName) iface, err := net.InterfaceByName(ifName)
@ -308,6 +299,18 @@ func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.R
return routes, nil return routes, nil
} }
// Close: close the Rtnetlink API
func (c *RtNetlinkConfig) Close() error { func (c *RtNetlinkConfig) Close() error {
return c.conn.Close() return c.conn.Close()
} }
// newRtNetlinkConfig: connect to the RtnetlinkAPI
func NewRtNetlinkConfig() (*RtNetlinkConfig, error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, err
}
return &RtNetlinkConfig{conn: conn}, nil
}

View File

@ -1,40 +0,0 @@
// lib contains helper functions for the implementation
package lib
import (
"cmp"
"math"
"gonum.org/v1/gonum/stat"
"gonum.org/v1/gonum/stat/distuv"
)
// Modelling the distribution using a normal distribution get the count
// of the outliers
func GetOutliers[K cmp.Ordered](counts map[K]uint64, alpha float64) []K {
n := float64(len(counts))
keys := MapKeys(counts)
values := make([]float64, len(keys))
for index, key := range keys {
values[index] = float64(counts[key])
}
mean := stat.Mean(values, nil)
stdDev := stat.StdDev(values, nil)
moe := distuv.Normal{Mu: 0, Sigma: 1}.Quantile(1-alpha/2) * (stdDev / math.Sqrt(n))
lowerBound := mean - moe
var outliers []K
for i, count := range values {
if count < lowerBound {
outliers = append(outliers, keys[i])
}
}
return outliers
}

View File

@ -6,6 +6,7 @@ import (
"os" "os"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/tim-beatham/smegmesh/pkg/conf"
) )
var ( var (
@ -39,17 +40,29 @@ func (l *LogrusLogger) Writer() io.Writer {
return l.logger.Writer() return l.logger.Writer()
} }
func NewLogrusLogger() *LogrusLogger { func NewLogrusLogger(confLevel conf.LogLevel) *LogrusLogger {
var level logrus.Level
switch confLevel {
case conf.ERROR:
level = logrus.ErrorLevel
case conf.WARNING:
level = logrus.WarnLevel
case conf.INFO:
level = logrus.InfoLevel
}
logger := logrus.New() logger := logrus.New()
logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true}) logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true})
logger.SetOutput(os.Stdout) logger.SetOutput(os.Stdout)
logger.SetLevel(logrus.InfoLevel) logger.SetLevel(level)
return &LogrusLogger{logger: logger} return &LogrusLogger{logger: logger}
} }
func init() { func init() {
SetLogger(NewLogrusLogger()) SetLogger(NewLogrusLogger(conf.INFO))
} }
func SetLogger(l Logger) { func SetLogger(l Logger) {

View File

@ -1,46 +0,0 @@
package mesh
import (
"fmt"
"github.com/tim-beatham/wgmesh/pkg/hosts"
)
type MeshAliasManager interface {
AddAliases(nodes []MeshNode)
RemoveAliases(node []MeshNode)
}
type AliasManager struct {
hosts hosts.HostsManipulator
}
// AddAliases: on node update or change add aliases to the hosts file
func (a *AliasManager) AddAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.AddAddr(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
// RemoveAliases: on node remove remove aliases from the hosts file
func (a *AliasManager) RemoveAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.Remove(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
func NewAliasManager() MeshAliasManager {
return &AliasManager{
hosts: hosts.NewHostsManipulator(),
}
}

View File

@ -7,22 +7,23 @@ import (
"strings" "strings"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/route" "github.com/tim-beatham/smegmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// MeshConfigApplyer abstracts applying the mesh configuration // MeshConfigApplier abstracts applying the mesh configuration
type MeshConfigApplyer interface { type MeshConfigApplier interface {
// ApplyConfig: apply the configurtation
ApplyConfig() error ApplyConfig() error
RemovePeers(meshId string) error // SetMeshManager: sets the associated manager
SetMeshManager(manager MeshManager) SetMeshManager(manager MeshManager)
} }
// WgMeshConfigApplyer applies WireGuard configuration // WgMeshConfigApplier: applies WireGuard configuration
type WgMeshConfigApplyer struct { type WgMeshConfigApplier struct {
meshManager MeshManager meshManager MeshManager
routeInstaller route.RouteInstaller routeInstaller route.RouteInstaller
hashFunc func(MeshNode) int hashFunc func(MeshNode) int
@ -35,14 +36,13 @@ type routeNode struct {
type convertMeshNodeParams struct { type convertMeshNodeParams struct {
node MeshNode node MeshNode
self MeshNode
mesh MeshProvider mesh MeshProvider
device *wgtypes.Device device *wgtypes.Device
peerToClients map[string][]net.IPNet peerToClients map[string][]net.IPNet
routes map[string][]routeNode routes map[string][]routeNode
} }
func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wgtypes.PeerConfig, error) { func (m *WgMeshConfigApplier) convertMeshNode(params convertMeshNodeParams) (*wgtypes.PeerConfig, error) {
pubKey, err := params.node.GetPublicKey() pubKey, err := params.node.GetPublicKey()
if err != nil { if err != nil {
@ -58,8 +58,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wg
allowedips = append(allowedips, clients...) allowedips = append(allowedips, clients...)
} }
for _, route := range params.node.GetRoutes() { for _, bestRoutes := range lib.MapValues(params.routes) {
bestRoutes := params.routes[route.GetDestination().String()]
var pickedRoute routeNode var pickedRoute routeNode
if len(bestRoutes) == 1 { if len(bestRoutes) == 1 {
@ -69,8 +68,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wg
return lib.HashString(rn.gateway) return lib.HashString(rn.gateway)
} }
// Else there is more than one candidate so consistently hash pickedRoute = lib.ConsistentHash(bestRoutes, params.node, bucketFunc, m.hashFunc)
pickedRoute = lib.ConsistentHash(bestRoutes, params.self, bucketFunc, m.hashFunc)
} }
if pickedRoute.gateway == pubKey.String() { if pickedRoute.gateway == pubKey.String() {
@ -91,7 +89,11 @@ func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wg
return p.PublicKey.String() == pubKey.String() return p.PublicKey.String() == pubKey.String()
}) })
endpoint, err := net.ResolveUDPAddr("udp", params.node.GetWgEndpoint()) var endpoint *net.UDPAddr = nil
if params.node.GetType() == conf.PEER_ROLE {
endpoint, err = net.ResolveUDPAddr("udp", params.node.GetWgEndpoint())
}
if err != nil { if err != nil {
return nil, err return nil, err
@ -115,8 +117,13 @@ func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wg
// getRoutes: finds the routes with the least hop distance. If more than one route exists // getRoutes: finds the routes with the least hop distance. If more than one route exists
// consistently hash to evenly spread the distribution of traffic // consistently hash to evenly spread the distribution of traffic
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode { func (m *WgMeshConfigApplier) getRoutes(meshProvider MeshProvider) (map[string][]routeNode, error) {
mesh, _ := meshProvider.GetMesh() mesh, err := meshProvider.GetMesh()
if err != nil {
return nil, err
}
routes := make(map[string][]routeNode) routes := make(map[string][]routeNode)
peers := lib.Filter(lib.MapValues(mesh.GetNodes()), func(p MeshNode) bool { peers := lib.Filter(lib.MapValues(mesh.GetNodes()), func(p MeshNode) bool {
@ -134,10 +141,7 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
for _, route := range node.GetRoutes() { for _, route := range node.GetRoutes() {
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool { if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
v6Default, _, _ := net.ParseCIDR("::/0") if prefix.IP.Equal(net.IPv6zero) && *meshProvider.GetConfiguration().AdvertiseDefaultRoute {
v4Default, _, _ := net.ParseCIDR("0.0.0.0/0")
if (prefix.IP.Equal(v6Default) || prefix.IP.Equal(v4Default)) && *meshProvider.GetConfiguration().AdvertiseDefaultRoute {
return true return true
} }
@ -157,16 +161,17 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
// Client's only acessible by another peer // Client's only acessible by another peer
if node.GetType() == conf.CLIENT_ROLE { if node.GetType() == conf.CLIENT_ROLE {
peer := m.getCorrespondingPeer(peers, node) peer := m.getCorrespondingPeer(peers, node)
self, _ := m.meshManager.GetSelf(meshProvider.GetMeshId()) self, err := meshProvider.GetNode(m.meshManager.GetPublicKey().String())
if err != nil {
return nil, err
}
// If the node isn't the self use that peer as the gateway
if !NodeEquals(peer, self) { if !NodeEquals(peer, self) {
peerPub, _ := peer.GetPublicKey() peerPub, _ := peer.GetPublicKey()
rn.gateway = peerPub.String() rn.gateway = peerPub.String()
rn.route = &RouteStub{ rn.route = &RouteStub{
Destination: rn.route.GetDestination(), Destination: rn.route.GetDestination(),
HopCount: rn.route.GetHopCount() + 1,
// Append the path to this peer
Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()), Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
} }
} }
@ -184,16 +189,17 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
} }
} }
return routes return routes, nil
} }
// getCorrespondignPeer: gets the peer corresponding to the client // getCorrespondignPeer: gets the peer corresponding to the client
func (m *WgMeshConfigApplyer) getCorrespondingPeer(peers []MeshNode, client MeshNode) MeshNode { func (m *WgMeshConfigApplier) getCorrespondingPeer(peers []MeshNode, client MeshNode) MeshNode {
peer := lib.ConsistentHash(peers, client, m.hashFunc, m.hashFunc) peer := lib.ConsistentHash(peers, client, m.hashFunc, m.hashFunc)
return peer return peer
} }
func (m *WgMeshConfigApplyer) getPeerCfgsToRemove(dev *wgtypes.Device, newPeers []wgtypes.PeerConfig) []wgtypes.PeerConfig { // getPeerCfgsToRemove: remove peer configurations that are no longer in the mesh
func (m *WgMeshConfigApplier) getPeerCfgsToRemove(dev *wgtypes.Device, newPeers []wgtypes.PeerConfig) []wgtypes.PeerConfig {
peers := dev.Peers peers := dev.Peers
peers = lib.Filter(peers, func(p1 wgtypes.Peer) bool { peers = lib.Filter(peers, func(p1 wgtypes.Peer) bool {
return !lib.Contains(newPeers, func(p2 wgtypes.PeerConfig) bool { return !lib.Contains(newPeers, func(p2 wgtypes.PeerConfig) bool {
@ -217,27 +223,37 @@ type GetConfigParams struct {
routes map[string][]routeNode routes map[string][]routeNode
} }
func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) { // getClientConfig: if the node is a client get their configuration
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId()) func (m *WgMeshConfigApplier) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) {
ula := &ip.ULABuilder{} ula := &ip.ULABuilder{}
meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId()) meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId())
routesForMesh := lib.Map(lib.MapValues(params.routes), func(rns []routeNode) []routeNode { routesForMesh := lib.Map(lib.MapValues(params.routes), func(rns []routeNode) []routeNode {
return lib.Filter(rns, func(rn routeNode) bool { return lib.Filter(rns, func(rn routeNode) bool {
ip, _, _ := net.ParseCIDR(rn.gateway) node, err := params.mesh.GetNode(rn.gateway)
return meshNet.Contains(ip) return node != nil && err == nil
}) })
}) })
routesForMesh = lib.Filter(routesForMesh, func(rns []routeNode) bool {
return len(rns) != 0
})
routes := lib.Map(routesForMesh, func(rs []routeNode) net.IPNet { routes := lib.Map(routesForMesh, func(rs []routeNode) net.IPNet {
return *rs[0].route.GetDestination() return *rs[0].route.GetDestination()
}) })
routes = append(routes, *meshNet) routes = append(routes, *meshNet)
self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(params.peers) == 0 {
return nil, fmt.Errorf("no peers in the mesh")
}
peer := m.getCorrespondingPeer(params.peers, self) peer := m.getCorrespondingPeer(params.peers, self)
pubKey, _ := peer.GetPublicKey() pubKey, _ := peer.GetPublicKey()
@ -263,30 +279,38 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes
installedRoutes := make([]lib.Route, 0) installedRoutes := make([]lib.Route, 0)
for _, route := range peerCfgs[0].AllowedIPs { for _, route := range peerCfgs[0].AllowedIPs {
// Don't install routes that we are directly a part
// Dont install default route wgctrl handles this for us
if !meshNet.Contains(route.IP) {
installedRoutes = append(installedRoutes, lib.Route{ installedRoutes = append(installedRoutes, lib.Route{
Gateway: peer.GetWgHost().IP, Gateway: peer.GetWgHost().IP,
Destination: route, Destination: route,
}) })
} }
}
cfg := wgtypes.Config{ cfg := wgtypes.Config{
Peers: peerCfgs, Peers: peerCfgs,
} }
if params.dev != nil {
m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...) m.routeInstaller.InstallRoutes(params.dev.Name, installedRoutes...)
}
return &cfg, err return &cfg, err
} }
func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mesh MeshProvider, node MeshNode) []lib.Route { // getRoutesToInstall: work out if the given node is advertising routes that should be installed into the
// RIB
func (m *WgMeshConfigApplier) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mesh MeshProvider, node MeshNode) []lib.Route {
routes := make([]lib.Route, 0) routes := make([]lib.Route, 0)
for _, route := range wgNode.AllowedIPs { for _, route := range wgNode.AllowedIPs {
ula := &ip.ULABuilder{} ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId()) ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
_, defaultRoute, _ := net.ParseCIDR("::/0") // Check there is no overlap in network and its not the default route
if !ipNet.Contains(route.IP) {
if !ipNet.Contains(route.IP) && !ipNet.IP.Equal(defaultRoute.IP) {
routes = append(routes, lib.Route{ routes = append(routes, lib.Route{
Gateway: node.GetWgHost().IP, Gateway: node.GetWgHost().IP,
Destination: route, Destination: route,
@ -297,11 +321,12 @@ func (m *WgMeshConfigApplyer) getRoutesToInstall(wgNode *wgtypes.PeerConfig, mes
return routes return routes
} }
func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.Config, error) { // getPeerConfig: creates the WireGuard configuration for a peer
func (m *WgMeshConfigApplier) getPeerConfig(params *GetConfigParams) (*wgtypes.Config, error) {
peerToClients := make(map[string][]net.IPNet) peerToClients := make(map[string][]net.IPNet)
installedRoutes := make([]lib.Route, 0) installedRoutes := make([]lib.Route, 0)
peerConfigs := make([]wgtypes.PeerConfig, 0) peerConfigs := make([]wgtypes.PeerConfig, 0)
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId()) self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil { if err != nil {
return nil, err return nil, err
@ -320,10 +345,8 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
peerToClients[pubKey.String()] = append(clients, *n.GetWgHost()) peerToClients[pubKey.String()] = append(clients, *n.GetWgHost())
if NodeEquals(self, peer) {
cfg, err := m.convertMeshNode(convertMeshNodeParams{ cfg, err := m.convertMeshNode(convertMeshNodeParams{
node: n, node: n,
self: self,
mesh: params.mesh, mesh: params.mesh,
device: params.dev, device: params.dev,
peerToClients: peerToClients, peerToClients: peerToClients,
@ -334,9 +357,11 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
return nil, err return nil, err
} }
installedRoutes = append(installedRoutes, m.getRoutesToInstall(cfg, params.mesh, n)...) if NodeEquals(self, peer) {
peerConfigs = append(peerConfigs, *cfg) peerConfigs = append(peerConfigs, *cfg)
} }
installedRoutes = append(installedRoutes, m.getRoutesToInstall(cfg, params.mesh, n)...)
} }
} }
@ -347,7 +372,6 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
peer, err := m.convertMeshNode(convertMeshNodeParams{ peer, err := m.convertMeshNode(convertMeshNodeParams{
node: n, node: n,
self: self,
mesh: params.mesh, mesh: params.mesh,
peerToClients: peerToClients, peerToClients: peerToClients,
routes: params.routes, routes: params.routes,
@ -370,7 +394,8 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
return &cfg, err return &cfg, err
} }
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string][]routeNode) error { // updateWgConf: update the WireGuard configuration
func (m *WgMeshConfigApplier) updateWgConf(mesh MeshProvider, routes map[string][]routeNode) error {
snap, err := mesh.GetMesh() snap, err := mesh.GetMesh()
if err != nil { if err != nil {
@ -378,7 +403,11 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string]
} }
nodes := lib.MapValues(snap.GetNodes()) nodes := lib.MapValues(snap.GetNodes())
dev, _ := mesh.GetDevice() dev, err := mesh.GetDevice()
if err != nil {
return err
}
slices.SortFunc(nodes, func(a, b MeshNode) int { slices.SortFunc(nodes, func(a, b MeshNode) int {
return strings.Compare(string(a.GetType()), string(b.GetType())) return strings.Compare(string(a.GetType()), string(b.GetType()))
@ -392,7 +421,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string]
return mn.GetType() == conf.CLIENT_ROLE return mn.GetType() == conf.CLIENT_ROLE
}) })
self, err := m.meshManager.GetSelf(mesh.GetMeshId()) self, err := mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil { if err != nil {
return err return err
@ -431,11 +460,17 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string]
return nil return nil
} }
func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode { // getAllRoutes: works out all the routes to install out of all the routes in the
// set of networks the node is a part of
func (m *WgMeshConfigApplier) getAllRoutes() (map[string][]routeNode, error) {
allRoutes := make(map[string][]routeNode) allRoutes := make(map[string][]routeNode)
for _, mesh := range m.meshManager.GetMeshes() { for _, mesh := range m.meshManager.GetMeshes() {
routes := m.getRoutes(mesh) routes, err := m.getRoutes(mesh)
if err != nil {
return nil, err
}
for destination, route := range routes { for destination, route := range routes {
_, ok := allRoutes[destination] _, ok := allRoutes[destination]
@ -453,11 +488,16 @@ func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode {
} }
} }
return allRoutes return allRoutes, nil
} }
func (m *WgMeshConfigApplyer) ApplyConfig() error { // ApplyConfig: apply the WireGuard configuration
allRoutes := m.getAllRoutes() func (m *WgMeshConfigApplier) ApplyConfig() error {
allRoutes, err := m.getAllRoutes()
if err != nil {
return err
}
for _, mesh := range m.meshManager.GetMeshes() { for _, mesh := range m.meshManager.GetMeshes() {
err := m.updateWgConf(mesh, allRoutes) err := m.updateWgConf(mesh, allRoutes)
@ -470,33 +510,12 @@ func (m *WgMeshConfigApplyer) ApplyConfig() error {
return nil return nil
} }
func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error { func (m *WgMeshConfigApplier) SetMeshManager(manager MeshManager) {
mesh := m.meshManager.GetMesh(meshId)
if mesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
dev, err := mesh.GetDevice()
if err != nil {
return err
}
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
Peers: make([]wgtypes.PeerConfig, 0),
ReplacePeers: true,
})
return nil
}
func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
m.meshManager = manager m.meshManager = manager
} }
func NewWgMeshConfigApplyer() MeshConfigApplyer { func NewWgMeshConfigApplier() MeshConfigApplier {
return &WgMeshConfigApplyer{ return &WgMeshConfigApplier{
routeInstaller: route.NewRouteInstaller(), routeInstaller: route.NewRouteInstaller(),
hashFunc: func(mn MeshNode) int { hashFunc: func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey() pubKey, _ := mn.GetPublicKey()

View File

@ -1,77 +0,0 @@
package mesh
import (
"errors"
"fmt"
"github.com/tim-beatham/wgmesh/pkg/graph"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// MeshGraphConverter converts a mesh to a graph
type MeshGraphConverter interface {
// convert the mesh to textual form
Generate(meshId string) (string, error)
}
type MeshDOTConverter struct {
manager MeshManager
}
func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
mesh := c.manager.GetMesh(meshId)
if mesh == nil {
return "", errors.New("mesh does not exist")
}
g := graph.NewGraph(meshId, graph.GRAPH)
snapshot, err := mesh.GetMesh()
if err != nil {
return "", err
}
for _, node := range snapshot.GetNodes() {
c.graphNode(g, node, meshId)
}
nodes := lib.MapValues(snapshot.GetNodes())
for i, node1 := range nodes[:len(nodes)-1] {
for _, node2 := range nodes[i+1:] {
if node1.GetWgEndpoint() == node2.GetWgEndpoint() {
continue
}
node1Id := fmt.Sprintf("\"%s\"", node1.GetIdentifier())
node2Id := fmt.Sprintf("\"%s\"", node2.GetIdentifier())
g.AddEdge(fmt.Sprintf("%s to %s", node1Id, node2Id), node1Id, node2Id)
}
}
return g.GetDOT()
}
// graphNode: graphs a node within the mesh
func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId string) {
nodeId := fmt.Sprintf("\"%s\"", node.GetIdentifier())
g.PutNode(nodeId, graph.CIRCLE)
self, _ := c.manager.GetSelf(meshId)
if NodeEquals(self, node) {
return
}
for _, route := range node.GetRoutes() {
routeId := fmt.Sprintf("\"%s\"", route)
g.PutNode(routeId, graph.HEXAGON)
g.AddEdge(fmt.Sprintf("%s to %s", nodeId, routeId), nodeId, routeId)
}
}
func NewMeshDotConverter(m MeshManager) MeshGraphConverter {
return &MeshDOTConverter{manager: m}
}

View File

@ -3,17 +3,21 @@ package mesh
import ( import (
"errors" "errors"
"fmt" "fmt"
"net"
"sync" "sync"
"github.com/tim-beatham/wgmesh/pkg/cmd" "github.com/tim-beatham/smegmesh/pkg/cmd"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// MeshManager: abstracts maanging meshes, including installing the WireGuard configuration
// to the device, and adding and removing nodes
type MeshManager interface { type MeshManager interface {
CreateMesh(params *CreateMeshParams) (string, error) CreateMesh(params *CreateMeshParams) (string, error)
AddMesh(params *AddMeshParams) error AddMesh(params *AddMeshParams) error
@ -24,72 +28,72 @@ type MeshManager interface {
LeaveMesh(meshId string) error LeaveMesh(meshId string) error
GetSelf(meshId string) (MeshNode, error) GetSelf(meshId string) (MeshNode, error)
ApplyConfig() error ApplyConfig() error
SetDescription(description string) error SetDescription(meshId, description string) error
SetAlias(alias string) error SetAlias(meshId, alias string) error
SetService(service string, value string) error SetService(meshId, service, value string) error
RemoveService(service string) error RemoveService(meshId, service string) error
UpdateTimeStamp() error UpdateTimeStamp() error
GetClient() *wgctrl.Client GetClient() *wgctrl.Client
GetMeshes() map[string]MeshProvider GetMeshes() map[string]MeshProvider
Close() error Close() error
GetMonitor() MeshMonitor
GetNode(string, string) MeshNode GetNode(string, string) MeshNode
GetRouteManager() RouteManager GetRouteManager() RouteManager
} }
type MeshManagerImpl struct { type MeshManagerImpl struct {
lock sync.RWMutex meshLock sync.RWMutex
Meshes map[string]MeshProvider meshes map[string]MeshProvider
RouteManager RouteManager RouteManager RouteManager
Client *wgctrl.Client Client *wgctrl.Client
// HostParameters contains information that uniquely locates
// the node in the mesh network.
HostParameters *HostParameters HostParameters *HostParameters
conf *conf.DaemonConfiguration conf *conf.DaemonConfiguration
meshProviderFactory MeshProviderFactory meshProviderFactory MeshProviderFactory
nodeFactory MeshNodeFactory nodeFactory MeshNodeFactory
configApplyer MeshConfigApplyer configApplier MeshConfigApplier
idGenerator lib.IdGenerator idGenerator lib.IdGenerator
ipAllocator ip.IPAllocator ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator interfaceManipulator wg.WgInterfaceManipulator
Monitor MeshMonitor
cmdRunner cmd.CmdRunner cmdRunner cmd.CmdRunner
OnDelete func(MeshProvider) OnDelete func(MeshProvider)
} }
// GetRouteManager implements MeshManager.
func (m *MeshManagerImpl) GetRouteManager() RouteManager { func (m *MeshManagerImpl) GetRouteManager() RouteManager {
return m.RouteManager return m.RouteManager
} }
// RemoveService implements MeshManager. // RemoveService: remove a service from the given mesh.
func (m *MeshManagerImpl) RemoveService(service string) error { func (m *MeshManagerImpl) RemoveService(meshId, service string) error {
for _, mesh := range m.Meshes { mesh := m.GetMesh(meshId)
err := mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
if err != nil { if mesh == nil {
return err return fmt.Errorf("mesh %s does not exist", meshId)
}
} }
return nil if !mesh.NodeExists(m.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
} }
// SetService implements MeshManager. // SetService: add a service to the given mesh
func (m *MeshManagerImpl) SetService(service string, value string) error { func (m *MeshManagerImpl) SetService(meshId, service, value string) error {
for _, mesh := range m.Meshes { mesh := m.GetMesh(meshId)
err := mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
if err != nil { if mesh == nil {
return err return fmt.Errorf("mesh %s does not exist", meshId)
}
} }
return nil if !mesh.NodeExists(m.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
} }
// GetNode: gets the node with given id in the mesh network
func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode { func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
mesh, ok := m.Meshes[meshid] mesh, ok := m.meshes[meshid]
if !ok { if !ok {
return nil return nil
@ -104,11 +108,6 @@ func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
return node return node
} }
// GetMonitor implements MeshManager.
func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
return m.Monitor
}
// CreateMeshParams contains the parameters required to create a mesh // CreateMeshParams contains the parameters required to create a mesh
type CreateMeshParams struct { type CreateMeshParams struct {
Port int Port int
@ -141,6 +140,10 @@ func (m *MeshManagerImpl) CreateMesh(args *CreateMeshParams) (string, error) {
return "", err return "", err
} }
if *meshConfiguration.Role == conf.CLIENT_ROLE {
return "", fmt.Errorf("cannot create mesh as a client")
}
meshId, err := m.idGenerator.GetId() meshId, err := m.idGenerator.GetId()
var ifName string = "" var ifName string = ""
@ -173,9 +176,9 @@ func (m *MeshManagerImpl) CreateMesh(args *CreateMeshParams) (string, error) {
return "", fmt.Errorf("error creating mesh: %w", err) return "", fmt.Errorf("error creating mesh: %w", err)
} }
m.lock.Lock() m.meshLock.Lock()
m.Meshes[meshId] = nodeManager m.meshes[meshId] = nodeManager
m.lock.Unlock() m.meshLock.Unlock()
m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PostUp...) m.cmdRunner.RunCommands(m.conf.BaseConfiguration.PostUp...)
@ -189,7 +192,7 @@ type AddMeshParams struct {
Conf *conf.WgConfiguration Conf *conf.WgConfiguration
} }
// AddMesh: Add the mesh to the list of meshes // AddMesh: Add a new mesh network to the list of addresses
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error { func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string var ifName string
var err error var err error
@ -232,20 +235,20 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
return err return err
} }
m.lock.Lock() m.meshLock.Lock()
m.Meshes[params.MeshId] = meshProvider m.meshes[params.MeshId] = meshProvider
m.lock.Unlock() m.meshLock.Unlock()
return nil return nil
} }
// HasChanges returns true if the mesh has changes // HasChanges: returns true if the mesh has changes
func (m *MeshManagerImpl) HasChanges(meshId string) bool { func (m *MeshManagerImpl) HasChanges(meshId string) bool {
return m.Meshes[meshId].HasChanges() return m.meshes[meshId].HasChanges()
} }
// GetMesh returns the mesh with the given meshid // GetMesh: returns the mesh with the given meshid
func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider { func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
theMesh := m.Meshes[meshId] theMesh := m.meshes[meshId]
return theMesh return theMesh
} }
@ -255,6 +258,8 @@ func (s *MeshManagerImpl) GetPublicKey() *wgtypes.Key {
return &key return &key
} }
// AddSelfParams: parameters required to add yourself to a mesh
// network
type AddSelfParams struct { type AddSelfParams struct {
// MeshId is the ID of the mesh to add this instance to // MeshId is the ID of the mesh to add this instance to
MeshId string MeshId string
@ -264,7 +269,7 @@ type AddSelfParams struct {
Endpoint string Endpoint string
} }
// AddSelf adds this host to the mesh // AddSelf: adds this host to the mesh
func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error { func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
mesh := s.GetMesh(params.MeshId) mesh := s.GetMesh(params.MeshId)
@ -284,12 +289,38 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
pubKey := s.HostParameters.PrivateKey.PublicKey() pubKey := s.HostParameters.PrivateKey.PublicKey()
nodeIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId) collisionCount := uint8(0)
var nodeIP net.IP
// Perform Duplicate Address Detection with the nodes
// that are already in the network
for {
generatedIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId, collisionCount)
if err != nil { if err != nil {
return err return err
} }
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
proposition := func(node MeshNode) bool {
ipNet := node.GetWgHost()
return ipNet.IP.Equal(nodeIP)
}
if lib.Contains(lib.MapValues(snapshot.GetNodes()), proposition) {
collisionCount++
} else {
nodeIP = generatedIP
break
}
}
node := s.nodeFactory.Build(&MeshNodeFactoryParams{ node := s.nodeFactory.Build(&MeshNodeFactoryParams{
PublicKey: &pubKey, PublicKey: &pubKey,
NodeIP: nodeIP, NodeIP: nodeIP,
@ -312,11 +343,11 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
} }
} }
s.Meshes[params.MeshId].AddNode(node) s.meshes[params.MeshId].AddNode(node)
return nil return nil
} }
// LeaveMesh leaves the mesh network // LeaveMesh: leaves the mesh network and force a synchronsiation
func (s *MeshManagerImpl) LeaveMesh(meshId string) error { func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
mesh := s.GetMesh(meshId) mesh := s.GetMesh(meshId)
@ -327,16 +358,16 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
err := mesh.RemoveNode(s.HostParameters.GetPublicKey()) err := mesh.RemoveNode(s.HostParameters.GetPublicKey())
if err != nil { if err != nil {
return err logging.Log.WriteErrorf(err.Error())
} }
if s.OnDelete != nil { if s.OnDelete != nil {
s.OnDelete(mesh) s.OnDelete(mesh)
} }
s.lock.Lock() s.meshLock.Lock()
delete(s.Meshes, meshId) delete(s.meshes, meshId)
s.lock.Unlock() s.meshLock.Unlock()
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PreDown...) s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PreDown...)
@ -355,12 +386,11 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
} }
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PostDown...) s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PostDown...)
return err return err
} }
func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) { func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
meshInstance, ok := s.Meshes[meshId] meshInstance, ok := s.meshes[meshId]
if !ok { if !ok {
return nil, fmt.Errorf("mesh %s does not exist", meshId) return nil, fmt.Errorf("mesh %s does not exist", meshId)
@ -375,51 +405,46 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
return node, nil return node, nil
} }
// ApplyConfig: applies the WireGuard configuration
// adds routes to the RIB and so forth.
func (s *MeshManagerImpl) ApplyConfig() error { func (s *MeshManagerImpl) ApplyConfig() error {
if s.conf.StubWg { if s.conf.StubWg {
return nil return nil
} }
return s.configApplier.ApplyConfig()
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
return nil
} }
func (s *MeshManagerImpl) SetDescription(description string) error { func (s *MeshManagerImpl) SetDescription(meshId, description string) error {
meshes := s.GetMeshes() mesh := s.GetMesh(meshId)
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
if err != nil { if mesh == nil {
return err return fmt.Errorf("mesh %s does not exist", meshId)
}
}
} }
return nil if !mesh.NodeExists(s.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
} }
// SetAlias implements MeshManager. // SetAlias sets the alias of the node for the given meshid
func (s *MeshManagerImpl) SetAlias(alias string) error { func (s *MeshManagerImpl) SetAlias(meshId, alias string) error {
meshes := s.GetMeshes() mesh := s.GetMesh(meshId)
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
if err != nil { if mesh == nil {
return err return fmt.Errorf("mesh %s does not exist", meshId)
} }
if !mesh.NodeExists(s.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
} }
}
return nil return mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
} }
// UpdateTimeStamp updates the timestamp of this node in all meshes // UpdateTimeStamp: updates the timestamp of this node in all meshes
// essentially performs heartbeat if the node is the leader
func (s *MeshManagerImpl) UpdateTimeStamp() error { func (s *MeshManagerImpl) UpdateTimeStamp() error {
meshes := s.GetMeshes() meshes := s.GetMeshes()
for _, mesh := range meshes { for _, mesh := range meshes {
@ -439,26 +464,30 @@ func (s *MeshManagerImpl) GetClient() *wgctrl.Client {
return s.Client return s.Client
} }
// GetMeshes: get all meshes the node is part of
func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider { func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
meshes := make(map[string]MeshProvider) meshes := make(map[string]MeshProvider)
s.lock.RLock() // GetMesh: copies the map of meshes to a new map
// to prevent a whole range of concurrency issues
// due to iteration and modification
s.meshLock.RLock()
for id, mesh := range s.Meshes { for id, mesh := range s.meshes {
meshes[id] = mesh meshes[id] = mesh
} }
s.lock.RUnlock() s.meshLock.RUnlock()
return meshes return meshes
} }
// Close the mesh manager // Close: close the mesh manager
func (s *MeshManagerImpl) Close() error { func (s *MeshManagerImpl) Close() error {
if s.conf.StubWg { if s.conf.StubWg {
return nil return nil
} }
for _, mesh := range s.Meshes { for _, mesh := range s.meshes {
dev, err := mesh.GetDevice() dev, err := mesh.GetDevice()
if err != nil { if err != nil {
@ -475,7 +504,7 @@ func (s *MeshManagerImpl) Close() error {
return nil return nil
} }
// NewMeshManagerParams params required to create an instance of a mesh manager // NewMeshManagerParams: params required to create an instance of a mesh manager
type NewMeshManagerParams struct { type NewMeshManagerParams struct {
Conf conf.DaemonConfiguration Conf conf.DaemonConfiguration
Client *wgctrl.Client Client *wgctrl.Client
@ -484,13 +513,13 @@ type NewMeshManagerParams struct {
IdGenerator lib.IdGenerator IdGenerator lib.IdGenerator
IPAllocator ip.IPAllocator IPAllocator ip.IPAllocator
InterfaceManipulator wg.WgInterfaceManipulator InterfaceManipulator wg.WgInterfaceManipulator
ConfigApplyer MeshConfigApplyer ConfigApplier MeshConfigApplier
RouteManager RouteManager RouteManager RouteManager
CommandRunner cmd.CmdRunner CommandRunner cmd.CmdRunner
OnDelete func(MeshProvider) OnDelete func(MeshProvider)
} }
// Creates a new instance of a mesh manager with the given parameters // NewMeshManager: Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) MeshManager { func NewMeshManager(params *NewMeshManagerParams) MeshManager {
privateKey, _ := wgtypes.GeneratePrivateKey() privateKey, _ := wgtypes.GeneratePrivateKey()
hostParams := HostParameters{ hostParams := HostParameters{
@ -498,7 +527,7 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
} }
m := &MeshManagerImpl{ m := &MeshManagerImpl{
Meshes: make(map[string]MeshProvider), meshes: make(map[string]MeshProvider),
HostParameters: &hostParams, HostParameters: &hostParams,
meshProviderFactory: params.MeshProvider, meshProviderFactory: params.MeshProvider,
nodeFactory: params.NodeFactory, nodeFactory: params.NodeFactory,
@ -506,11 +535,11 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
conf: &params.Conf, conf: &params.Conf,
} }
m.configApplyer = params.ConfigApplyer m.configApplier = params.ConfigApplier
m.RouteManager = params.RouteManager m.RouteManager = params.RouteManager
if m.RouteManager == nil { if m.RouteManager == nil {
m.RouteManager = NewRouteManager(m, &params.Conf) m.RouteManager = NewRouteManager(m)
} }
if params.CommandRunner == nil { if params.CommandRunner == nil {
@ -521,11 +550,6 @@ func NewMeshManager(params *NewMeshManagerParams) MeshManager {
m.ipAllocator = params.IPAllocator m.ipAllocator = params.IPAllocator
m.interfaceManipulator = params.InterfaceManipulator m.interfaceManipulator = params.InterfaceManipulator
m.Monitor = NewMeshMonitor(m)
aliasManager := NewAliasManager()
m.Monitor.AddUpdateCallback(aliasManager.AddAliases)
m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases)
m.OnDelete = params.OnDelete m.OnDelete = params.OnDelete
return m return m
} }

View File

@ -3,15 +3,38 @@ package mesh
import ( import (
"testing" "testing"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg" "github.com/tim-beatham/smegmesh/pkg/wg"
) )
func getMeshConfiguration() *conf.DaemonConfiguration { func getMeshConfiguration() *conf.DaemonConfiguration {
advertiseRoutes := true
advertiseDefaultRoute := true
ipDiscovery := conf.PUBLIC_IP_DISCOVERY
role := conf.PEER_ROLE
return &conf.DaemonConfiguration{ return &conf.DaemonConfiguration{
GrpcPort: 8080, GrpcPort: 8080,
CertificatePath: "./somecertificatepath",
PrivateKeyPath: "./someprivatekeypath",
CaCertificatePath: "./somecacertificatepath",
SkipCertVerification: true,
Timeout: 5,
StubWg: true,
SyncInterval: 2,
Heartbeat: 60,
ClusterSize: 64,
InterClusterChance: 0.15,
Branch: 3,
InfectionCount: 3,
BaseConfiguration: conf.WgConfiguration{
IPDiscovery: &ipDiscovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Role: &role,
},
} }
} }
@ -24,7 +47,7 @@ func getMeshManager() MeshManager {
IdGenerator: &lib.UUIDGenerator{}, IdGenerator: &lib.UUIDGenerator{},
IPAllocator: &ip.ULABuilder{}, IPAllocator: &ip.ULABuilder{},
InterfaceManipulator: &wg.WgInterfaceManipulatorStub{}, InterfaceManipulator: &wg.WgInterfaceManipulatorStub{},
ConfigApplyer: &MeshConfigApplyerStub{}, ConfigApplier: &MeshConfigApplierStub{},
RouteManager: &RouteManagerStub{}, RouteManager: &RouteManagerStub{},
}) })
@ -34,7 +57,10 @@ func getMeshManager() MeshManager {
func TestCreateMeshCreatesANewMeshProvider(t *testing.T) { func TestCreateMeshCreatesANewMeshProvider(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
meshId, err := manager.CreateMesh("wg0", 5000) meshId, err := manager.CreateMesh(&CreateMeshParams{
Port: 0,
Conf: &conf.WgConfiguration{},
})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -121,7 +147,7 @@ func TestAddSelfAddsSelfToTheMesh(t *testing.T) {
t.Error(err) t.Error(err)
} }
_, ok := mesh.GetNodes()["abc.com"] _, ok := mesh.GetNodes()[manager.GetPublicKey().String()]
if !ok { if !ok {
t.Fatalf(`node has not been added`) t.Fatalf(`node has not been added`)
@ -186,36 +212,80 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
} }
} }
func TestSetDescription(t *testing.T) { func TestSetAliasUpdatesAliasOfNode(t *testing.T) {
manager := getMeshManager()
alias := "Firpo"
meshId, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.SetAlias(meshId, alias)
if err != nil {
t.Fatalf(`failed to set the alias`)
}
self, err := manager.GetSelf(meshId)
if err != nil {
t.Fatalf(`failed to set the alias err: %s`, err.Error())
}
if alias != self.GetAlias() {
t.Fatalf(`alias should be %s was %s`, alias, self.GetAlias())
}
}
func TestSetDescriptionSetsTheDescriptionOfTheNode(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
description := "wooooo" description := "wooooo"
meshId1, _ := manager.CreateMesh(5000) meshId1, _ := manager.CreateMesh(&CreateMeshParams{
meshId2, _ := manager.CreateMesh(5001) Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{ manager.AddSelf(&AddSelfParams{
MeshId: meshId1, MeshId: meshId1,
WgPort: 5000, WgPort: 5000,
Endpoint: "abc.com:8080", Endpoint: "abc.com:8080",
}) })
manager.AddSelf(&AddSelfParams{
MeshId: meshId2,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.SetDescription(description) err := manager.SetDescription(meshId1, description)
if err != nil { if err != nil {
t.Fatalf(`failed to set the descriptions`) t.Fatalf(`failed to set the descriptions`)
} }
}
self1, err := manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`failed to set the description`)
}
if description != self1.GetDescription() {
t.Fatalf(`description should be %s was %s`, description, self1.GetDescription())
}
}
func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) { func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
meshId1, _ := manager.CreateMesh(5000) meshId1, _ := manager.CreateMesh(&CreateMeshParams{
meshId2, _ := manager.CreateMesh(5001) Port: 5000,
Conf: &conf.WgConfiguration{},
})
meshId2, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5001,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{ manager.AddSelf(&AddSelfParams{
MeshId: meshId1, MeshId: meshId1,
@ -234,3 +304,68 @@ func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
t.Fatalf(`failed to update the timestamp`) t.Fatalf(`failed to update the timestamp`)
} }
} }
func TestAddServiceAddsServiceToTheMesh(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
serviceName := "hello"
manager.SetService(meshId1, serviceName, "dave")
self, err := manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; !ok {
t.Fatalf(`service not added`)
}
}
func TestRemoveServiceRemovesTheServiceFromTheMesh(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
serviceName := "hello"
manager.SetService(meshId1, serviceName, "dave")
self, err := manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; !ok {
t.Fatalf(`service not added`)
}
manager.RemoveService(meshId1, serviceName)
self, err = manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; ok {
t.Fatalf(`service still exists`)
}
}

View File

@ -1,81 +0,0 @@
package mesh
type OnChange = func([]MeshNode)
type MeshMonitor interface {
AddUpdateCallback(cb OnChange)
AddRemoveCallback(cb OnChange)
Trigger() error
}
type MeshMonitorImpl struct {
updateCbs []OnChange
removeCbs []OnChange
nodes map[string]MeshNode
manager MeshManager
}
// Trigger causes the mesh monitor to trigger all of
// the callbacks.
func (m *MeshMonitorImpl) Trigger() error {
changedNodes := make([]MeshNode, 0)
removedNodes := make([]MeshNode, 0)
nodes := make(map[string]MeshNode)
for _, mesh := range m.manager.GetMeshes() {
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
for _, node := range snapshot.GetNodes() {
previous, exists := m.nodes[node.GetWgHost().String()]
if !exists || !NodeEquals(previous, node) {
changedNodes = append(changedNodes, node)
}
nodes[node.GetWgHost().String()] = node
}
}
for _, previous := range m.nodes {
_, ok := nodes[previous.GetWgHost().String()]
if !ok {
removedNodes = append(removedNodes, previous)
}
}
if len(removedNodes) > 0 {
for _, cb := range m.removeCbs {
cb(removedNodes)
}
}
if len(changedNodes) > 0 {
for _, cb := range m.updateCbs {
cb(changedNodes)
}
}
return nil
}
func (m *MeshMonitorImpl) AddUpdateCallback(cb OnChange) {
m.updateCbs = append(m.updateCbs, cb)
}
func (m *MeshMonitorImpl) AddRemoveCallback(cb OnChange) {
m.removeCbs = append(m.removeCbs, cb)
}
func NewMeshMonitor(manager MeshManager) MeshMonitor {
return &MeshMonitorImpl{
updateCbs: make([]OnChange, 0),
nodes: make(map[string]MeshNode),
manager: manager,
}
}

View File

@ -3,26 +3,35 @@ package mesh
import ( import (
"net" "net"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/lib"
) )
// RouteManager: manager that leaks routes between meshes
type RouteManager interface { type RouteManager interface {
// UpdateRoutes: leak all routes in each mesh
UpdateRoutes() error UpdateRoutes() error
} }
type RouteManagerImpl struct { type RouteManagerImpl struct {
meshManager MeshManager meshManager MeshManager
conf *conf.DaemonConfiguration
} }
func (r *RouteManagerImpl) UpdateRoutes() error { func (r *RouteManagerImpl) UpdateRoutes() error {
meshes := r.meshManager.GetMeshes() meshes := r.meshManager.GetMeshes()
routes := make(map[string][]Route) routes := make(map[string][]Route)
for _, mesh := range meshes {
// Make empty routes so that routes are retracted
routes[mesh.GetMeshId()] = make([]Route, 0)
}
for _, mesh1 := range meshes { for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(mesh1.GetMeshId()) if !*mesh1.GetConfiguration().AdvertiseRoutes {
continue
}
self, err := mesh1.GetNode(r.meshManager.GetPublicKey().String())
if err != nil { if err != nil {
return err return err
@ -32,23 +41,24 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
routes[mesh1.GetMeshId()] = make([]Route, 0) routes[mesh1.GetMeshId()] = make([]Route, 0)
} }
if *mesh1.GetConfiguration().AdvertiseDefaultRoute {
_, ipv6Default, _ := net.ParseCIDR("::/0")
defaultRoute := &RouteStub{
Destination: ipv6Default,
Path: []string{mesh1.GetMeshId()},
}
mesh1.AddRoutes(NodeID(self), defaultRoute)
routes[mesh1.GetMeshId()] = append(routes[mesh1.GetMeshId()], defaultRoute)
}
routeMap, err := mesh1.GetRoutes(NodeID(self)) routeMap, err := mesh1.GetRoutes(NodeID(self))
if err != nil { if err != nil {
return err return err
} }
if *r.conf.BaseConfiguration.AdvertiseDefaultRoute {
_, ipv6Default, _ := net.ParseCIDR("::/0")
mesh1.AddRoutes(NodeID(self),
&RouteStub{
Destination: ipv6Default,
HopCount: 0,
Path: make([]string, 0),
})
}
for _, mesh2 := range meshes { for _, mesh2 := range meshes {
routeValues, ok := routes[mesh2.GetMeshId()] routeValues, ok := routes[mesh2.GetMeshId()]
@ -64,7 +74,6 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
routeValues = append(routeValues, &RouteStub{ routeValues = append(routeValues, &RouteStub{
Destination: mesh1IpNet, Destination: mesh1IpNet,
HopCount: 0,
Path: []string{mesh1.GetMeshId()}, Path: []string{mesh1.GetMeshId()},
}) })
@ -75,8 +84,9 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return s == mesh2.GetMeshId() return s == mesh2.GetMeshId()
} }
// Ensure that the route does not see it's own IP // Remove any potential routing loops
return !r.GetDestination().IP.Equal(mesh2IpNet.IP) && !lib.Contains(r.GetPath()[1:], pathNotMesh) return !r.GetDestination().IP.Equal(mesh2IpNet.IP) &&
!lib.Contains(r.GetPath()[1:], pathNotMesh)
}) })
routes[mesh2.GetMeshId()] = routeValues routes[mesh2.GetMeshId()] = routeValues
@ -85,15 +95,21 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
// Calculate the set different of each, working out routes to remove and to keep. // Calculate the set different of each, working out routes to remove and to keep.
for meshId, meshRoutes := range routes { for meshId, meshRoutes := range routes {
mesh := r.meshManager.GetMesh(meshId) mesh := meshes[meshId]
self, _ := r.meshManager.GetSelf(meshId)
self, err := mesh.GetNode(r.meshManager.GetPublicKey().String())
if err != nil {
return err
}
toRemove := make([]Route, 0) toRemove := make([]Route, 0)
prevRoutes, _ := mesh.GetRoutes(NodeID(self)) prevRoutes := self.GetRoutes()
for _, route := range prevRoutes { for _, route := range prevRoutes {
if !lib.Contains(meshRoutes, func(r Route) bool { if !lib.Contains(meshRoutes, func(r Route) bool {
return RouteEquals(r, route) return RouteEqual(r, route)
}) { }) {
toRemove = append(toRemove, route) toRemove = append(toRemove, route)
} }
@ -106,6 +122,6 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return nil return nil
} }
func NewRouteManager(m MeshManager, conf *conf.DaemonConfiguration) RouteManager { func NewRouteManager(m MeshManager) RouteManager {
return &RouteManagerImpl{meshManager: m, conf: conf} return &RouteManagerImpl{meshManager: m}
} }

View File

@ -5,7 +5,8 @@ import (
"net" "net"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -19,6 +20,8 @@ type MeshNodeStub struct {
routes []Route routes []Route
identifier string identifier string
description string description string
alias string
services map[string]string
} }
// GetType implements MeshNode. // GetType implements MeshNode.
@ -27,13 +30,13 @@ func (*MeshNodeStub) GetType() conf.NodeType {
} }
// GetServices implements MeshNode. // GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string { func (m *MeshNodeStub) GetServices() map[string]string {
return make(map[string]string) return m.services
} }
// GetAlias implements MeshNode. // GetAlias implements MeshNode.
func (*MeshNodeStub) GetAlias() string { func (s *MeshNodeStub) GetAlias() string {
return "" return s.alias
} }
func (m *MeshNodeStub) GetHostEndpoint() string { func (m *MeshNodeStub) GetHostEndpoint() string {
@ -83,17 +86,26 @@ type MeshProviderStub struct {
// GetConfiguration implements MeshProvider. // GetConfiguration implements MeshProvider.
func (*MeshProviderStub) GetConfiguration() *conf.WgConfiguration { func (*MeshProviderStub) GetConfiguration() *conf.WgConfiguration {
panic("unimplemented") advertiseRoutes := true
advertiseDefaultRoute := true
ipDiscovery := conf.PUBLIC_IP_DISCOVERY
role := conf.PEER_ROLE
return &conf.WgConfiguration{
IPDiscovery: &ipDiscovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Role: &role,
}
} }
// Mark implements MeshProvider. // Mark implements MeshProvider.
func (*MeshProviderStub) Mark(nodeId string) { func (*MeshProviderStub) Mark(nodeId string) {
panic("unimplemented")
} }
// RemoveNode implements MeshProvider. // RemoveNode implements MeshProvider.
func (*MeshProviderStub) RemoveNode(nodeId string) error { func (*MeshProviderStub) RemoveNode(nodeId string) error {
panic("unimplemented") return nil
} }
func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) { func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) {
@ -106,32 +118,53 @@ func (*MeshProviderStub) GetPeers() []string {
} }
// GetNode implements MeshProvider. // GetNode implements MeshProvider.
func (*MeshProviderStub) GetNode(string) (MeshNode, error) { func (m *MeshProviderStub) GetNode(nodeId string) (MeshNode, error) {
return nil, nil return m.snapshot.nodes[nodeId], nil
} }
// NodeExists implements MeshProvider. // NodeExists implements MeshProvider.
func (*MeshProviderStub) NodeExists(string) bool { func (m *MeshProviderStub) NodeExists(nodeId string) bool {
return false return m.snapshot.nodes[nodeId] != nil
} }
// AddService implements MeshProvider. // AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error { func (m *MeshProviderStub) AddService(nodeId string, key string, value string) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
node.services[key] = value
return nil return nil
} }
// RemoveService implements MeshProvider. // RemoveService implements MeshProvider.
func (*MeshProviderStub) RemoveService(nodeId string, key string) error { func (m *MeshProviderStub) RemoveService(nodeId string, key string) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
delete(node.services, key)
return nil return nil
} }
// SetAlias implements MeshProvider. // SetAlias implements MeshProvider.
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error { func (m *MeshProviderStub) SetAlias(nodeId string, alias string) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
node.alias = alias
return nil
}
// AddRoutes implements
func (m *MeshProviderStub) AddRoutes(nodeId string, route ...Route) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
node.routes = append(node.routes, route...)
return nil return nil
} }
// RemoveRoutes implements MeshProvider. // RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...Route) error { func (m *MeshProviderStub) RemoveRoutes(nodeId string, route ...Route) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
newRoutes := lib.Filter(node.routes, func(r1 Route) bool {
return !lib.Contains(route, func(r2 Route) bool {
return RouteEqual(r1, r2)
})
})
node.routes = newRoutes
return nil return nil
} }
@ -141,12 +174,15 @@ func (*MeshProviderStub) Prune() error {
} }
// UpdateTimeStamp implements MeshProvider. // UpdateTimeStamp implements MeshProvider.
func (*MeshProviderStub) UpdateTimeStamp(nodeId string) error { func (m *MeshProviderStub) UpdateTimeStamp(nodeId string) error {
node := (m.snapshot.nodes[nodeId]).(*MeshNodeStub)
node.timeStamp = time.Now().Unix()
return nil return nil
} }
func (s *MeshProviderStub) AddNode(node MeshNode) { func (s *MeshProviderStub) AddNode(node MeshNode) {
s.snapshot.nodes[node.GetHostEndpoint()] = node pubKey, _ := node.GetPublicKey()
s.snapshot.nodes[pubKey.String()] = node
} }
func (s *MeshProviderStub) GetMesh() (MeshSnapshot, error) { func (s *MeshProviderStub) GetMesh() (MeshSnapshot, error) {
@ -178,15 +214,13 @@ func (s *MeshProviderStub) HasChanges() bool {
return false return false
} }
func (s *MeshProviderStub) AddRoutes(nodeId string, route ...Route) error {
return nil
}
func (s *MeshProviderStub) GetSyncer() MeshSyncer { func (s *MeshProviderStub) GetSyncer() MeshSyncer {
return nil return nil
} }
func (s *MeshProviderStub) SetDescription(nodeId string, description string) error { func (s *MeshProviderStub) SetDescription(nodeId string, description string) error {
meshNode := (s.snapshot.nodes[nodeId]).(*MeshNodeStub)
meshNode.description = description
return nil return nil
} }
@ -209,26 +243,27 @@ func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
return &MeshNodeStub{ return &MeshNodeStub{
hostEndpoint: params.Endpoint, hostEndpoint: params.Endpoint,
publicKey: *params.PublicKey, publicKey: *params.PublicKey,
wgEndpoint: fmt.Sprintf("%s:%s", params.Endpoint, s.Config.GrpcPort), wgEndpoint: fmt.Sprintf("%s:%d", params.Endpoint, s.Config.GrpcPort),
wgHost: wgHost, wgHost: wgHost,
timeStamp: time.Now().Unix(), timeStamp: time.Now().Unix(),
routes: make([]Route, 0), routes: make([]Route, 0),
identifier: "abc", identifier: "abc",
description: "A Mesh Node Stub", description: "A Mesh Node Stub",
services: make(map[string]string),
} }
} }
type MeshConfigApplyerStub struct{} type MeshConfigApplierStub struct{}
func (a *MeshConfigApplyerStub) ApplyConfig() error { func (a *MeshConfigApplierStub) ApplyConfig() error {
return nil return nil
} }
func (a *MeshConfigApplyerStub) RemovePeers(meshId string) error { func (a *MeshConfigApplierStub) RemovePeers(meshId string) error {
return nil return nil
} }
func (a *MeshConfigApplyerStub) SetMeshManager(manager MeshManager) { func (a *MeshConfigApplierStub) SetMeshManager(manager MeshManager) {
} }
type MeshManagerStub struct { type MeshManagerStub struct {
@ -237,37 +272,32 @@ type MeshManagerStub struct {
// GetRouteManager implements MeshManager. // GetRouteManager implements MeshManager.
func (*MeshManagerStub) GetRouteManager() RouteManager { func (*MeshManagerStub) GetRouteManager() RouteManager {
panic("unimplemented") return nil
} }
// GetNode implements MeshManager. // GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode { func (*MeshManagerStub) GetNode(meshId, nodeId string) MeshNode {
panic("unimplemented") return nil
} }
// RemoveService implements MeshManager. // RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(service string) error { func (*MeshManagerStub) RemoveService(meshId, service string) error {
panic("unimplemented") return nil
} }
// SetService implements MeshManager. // SetService implements MeshManager.
func (*MeshManagerStub) SetService(service string, value string) error { func (*MeshManagerStub) SetService(meshId, service, value string) error {
panic("unimplemented") return nil
}
// GetMonitor implements MeshManager.
func (*MeshManagerStub) GetMonitor() MeshMonitor {
panic("unimplemented")
} }
// SetAlias implements MeshManager. // SetAlias implements MeshManager.
func (*MeshManagerStub) SetAlias(alias string) error { func (*MeshManagerStub) SetAlias(meshId, alias string) error {
panic("unimplemented") return nil
} }
// Close implements MeshManager. // Close implements MeshManager.
func (*MeshManagerStub) Close() error { func (*MeshManagerStub) Close() error {
panic("unimplemented") return nil
} }
// Prune implements MeshManager. // Prune implements MeshManager.
@ -319,7 +349,7 @@ func (m *MeshManagerStub) ApplyConfig() error {
return nil return nil
} }
func (m *MeshManagerStub) SetDescription(description string) error { func (m *MeshManagerStub) SetDescription(meshId, description string) error {
return nil return nil
} }

View File

@ -6,7 +6,7 @@ import (
"net" "net"
"slices" "slices"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -20,7 +20,7 @@ type Route interface {
GetPath() []string GetPath() []string
} }
func RouteEquals(r1, r2 Route) bool { func RouteEqual(r1 Route, r2 Route) bool {
return r1.GetDestination().String() == r2.GetDestination().String() && return r1.GetDestination().String() == r2.GetDestination().String() &&
r1.GetHopCount() == r2.GetHopCount() && r1.GetHopCount() == r2.GetHopCount() &&
slices.Equal(r1.GetPath(), r2.GetPath()) slices.Equal(r1.GetPath(), r2.GetPath())
@ -28,7 +28,6 @@ func RouteEquals(r1, r2 Route) bool {
type RouteStub struct { type RouteStub struct {
Destination *net.IPNet Destination *net.IPNet
HopCount int
Path []string Path []string
} }
@ -37,7 +36,7 @@ func (r *RouteStub) GetDestination() *net.IPNet {
} }
func (r *RouteStub) GetHopCount() int { func (r *RouteStub) GetHopCount() int {
return r.HopCount return len(r.Path)
} }
func (r *RouteStub) GetPath() []string { func (r *RouteStub) GetPath() []string {
@ -75,6 +74,10 @@ func NodeEquals(node1, node2 MeshNode) bool {
key1, _ := node1.GetPublicKey() key1, _ := node1.GetPublicKey()
key2, _ := node2.GetPublicKey() key2, _ := node2.GetPublicKey()
if node1 == nil || node2 == nil {
return false
}
return key1.String() == key2.String() return key1.String() == key2.String()
} }

View File

@ -6,9 +6,9 @@ import (
"strings" "strings"
"github.com/jmespath/go-jmespath" "github.com/jmespath/go-jmespath"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
// Querier queries a data store for the given data // Querier queries a data store for the given data
@ -17,20 +17,24 @@ type Querier interface {
Query(meshId string, queryParams string) ([]byte, error) Query(meshId string, queryParams string) ([]byte, error)
} }
// JmesQuerier: queries the datstore in JMESPath syntax
type JmesQuerier struct { type JmesQuerier struct {
manager mesh.MeshManager manager mesh.MeshManager
} }
// QueryError: query error if something went wrong
type QueryError struct { type QueryError struct {
msg string msg string
} }
// QuerRoute: represents a route in the query
type QueryRoute struct { type QueryRoute struct {
Destination string `json:"destination"` Destination string `json:"destination"`
HopCount int `json:"hopCount"` HopCount int `json:"hopCount"`
Path string `json:"path"` Path string `json:"path"`
} }
// QueryNode: represents a single node in the query
type QueryNode struct { type QueryNode struct {
HostEndpoint string `json:"hostEndpoint"` HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
@ -48,7 +52,7 @@ func (m *QueryError) Error() string {
return m.msg return m.msg
} }
// Query: queries the data // Query: queries the the datastore at the given meshid
func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) { func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
mesh, ok := j.manager.GetMeshes()[meshId] mesh, ok := j.manager.GetMeshes()[meshId]
@ -74,6 +78,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return bytes, err return bytes, err
} }
// MeshNodeToQuerynode: convert the mesh node into a query abstraction
func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode { func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode := new(QueryNode) queryNode := new(QueryNode)
queryNode.HostEndpoint = node.GetHostEndpoint() queryNode.HostEndpoint = node.GetHostEndpoint()

View File

@ -1,302 +0,0 @@
package robin
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc"
)
type IpcHandler struct {
Server ctrlserver.CtrlServer
}
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
overrideConf := &conf.WgConfiguration{}
if args.Role != "" {
role := conf.NodeType(args.Role)
overrideConf.Role = &role
}
if args.Endpoint != "" {
overrideConf.Endpoint = &args.Endpoint
}
if *overrideConf.Role == conf.CLIENT_ROLE {
return fmt.Errorf("cannot create a mesh with no public endpoint")
}
meshId, err := n.Server.GetMeshManager().CreateMesh(&mesh.CreateMeshParams{
Port: args.WgPort,
Conf: overrideConf,
})
if err != nil {
return err
}
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
MeshId: meshId,
WgPort: args.WgPort,
Endpoint: args.Endpoint,
})
if err != nil {
return err
}
*reply = meshId
return err
}
func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
meshNames := make([]string, len(n.Server.GetMeshManager().GetMeshes()))
i := 0
for meshId := range n.Server.GetMeshManager().GetMeshes() {
meshNames[i] = meshId
i++
}
*reply = ipc.ListMeshReply{Meshes: meshNames}
return nil
}
func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
overrideConf := &conf.WgConfiguration{}
if args.Role != "" {
role := conf.NodeType(args.Role)
overrideConf.Role = &role
}
if args.Endpoint != "" {
overrideConf.Endpoint = &args.Endpoint
}
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress)
if err != nil {
return err
}
client, err := peerConnection.GetClient()
if err != nil {
return err
}
c := rpc.NewMeshCtrlServerClient(client)
if err != nil {
return err
}
configuration := n.Server.GetConfiguration()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(configuration.Timeout))
defer cancel()
meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId})
if err != nil {
return err
}
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: args.MeshId,
WgPort: args.Port,
MeshBytes: meshReply.Mesh,
Conf: overrideConf,
})
if err != nil {
return err
}
err = n.Server.GetMeshManager().AddSelf(&mesh.AddSelfParams{
MeshId: args.MeshId,
WgPort: args.Port,
Endpoint: args.Endpoint,
})
if err != nil {
return err
}
*reply = strconv.FormatBool(true)
return nil
}
// LeaveMesh leaves a mesh network
func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
err := n.Server.GetMeshManager().LeaveMesh(meshId)
if err == nil {
*reply = fmt.Sprintf("Left Mesh %s", meshId)
}
return err
}
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
theMesh := n.Server.GetMeshManager().GetMesh(meshId)
if theMesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
meshSnapshot, err := theMesh.GetMesh()
if err != nil {
return err
}
if theMesh == nil {
return errors.New("mesh does not exist")
}
nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.GetNodes()))
i := 0
for _, node := range meshSnapshot.GetNodes() {
pubKey, _ := node.GetPublicKey()
if err != nil {
return err
}
node := ctrlserver.MeshNode{
HostEndpoint: node.GetHostEndpoint(),
WgEndpoint: node.GetWgEndpoint(),
PublicKey: pubKey.String(),
WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(),
Routes: lib.Map(node.GetRoutes(), func(r mesh.Route) ctrlserver.MeshRoute {
return ctrlserver.MeshRoute{
Destination: r.GetDestination().String(),
Path: r.GetPath(),
}
}),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
}
nodes[i] = node
i += 1
}
*reply = ipc.GetMeshReply{Nodes: nodes}
return nil
}
func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
g := mesh.NewMeshDotConverter(n.Server.GetMeshManager())
result, err := g.Generate(meshId)
if err != nil {
return err
}
*reply = result
return nil
}
func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
queryResponse, err := n.Server.GetQuerier().Query(params.MeshId, params.Query)
if err != nil {
return err
}
*reply = string(queryResponse)
return nil
}
func (n *IpcHandler) PutDescription(description string, reply *string) error {
err := n.Server.GetMeshManager().SetDescription(description)
if err != nil {
return err
}
*reply = fmt.Sprintf("Set description to %s", description)
return nil
}
func (n *IpcHandler) PutAlias(alias string, reply *string) error {
err := n.Server.GetMeshManager().SetAlias(alias)
if err != nil {
return err
}
*reply = fmt.Sprintf("Set alias to %s", alias)
return nil
}
func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().SetService(service.Service, service.Value)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) DeleteService(service string, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) GetNode(args ipc.GetNodeArgs, reply *string) error {
node := n.Server.GetMeshManager().GetNode(args.MeshId, args.NodeId)
if node == nil {
*reply = "nil"
return nil
}
queryNode := query.MeshNodeToQueryNode(node)
bytes, err := json.Marshal(queryNode)
if err != nil {
*reply = err.Error()
return nil
}
*reply = string(bytes)
return nil
}
type RobinIpcParams struct {
CtrlServer ctrlserver.CtrlServer
}
func NewRobinIpc(ipcParams RobinIpcParams) IpcHandler {
return IpcHandler{
Server: ipcParams.CtrlServer,
}
}

View File

@ -1 +0,0 @@
package robin

View File

@ -1,10 +1,11 @@
package route package route
import ( import (
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// RouteInstaller: install the routes to the given interface
type RouteInstaller interface { type RouteInstaller interface {
InstallRoutes(devName string, routes ...lib.Route) error InstallRoutes(devName string, routes ...lib.Route) error
} }
@ -19,6 +20,8 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route)
return err return err
} }
defer rtnl.Close()
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...) err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...)
if err != nil { if err != nil {

View File

@ -1,235 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: pkg/grpc/ctrlserver/authentication.proto
package rpc
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type JoinAuthMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"`
Alias string `protobuf:"bytes,2,opt,name=alias,proto3" json:"alias,omitempty"`
}
func (x *JoinAuthMeshRequest) Reset() {
*x = JoinAuthMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinAuthMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinAuthMeshRequest) ProtoMessage() {}
func (x *JoinAuthMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinAuthMeshRequest.ProtoReflect.Descriptor instead.
func (*JoinAuthMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_authentication_proto_rawDescGZIP(), []int{0}
}
func (x *JoinAuthMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
func (x *JoinAuthMeshRequest) GetAlias() string {
if x != nil {
return x.Alias
}
return ""
}
type JoinAuthMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
Token *string `protobuf:"bytes,2,opt,name=token,proto3,oneof" json:"token,omitempty"`
}
func (x *JoinAuthMeshReply) Reset() {
*x = JoinAuthMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinAuthMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinAuthMeshReply) ProtoMessage() {}
func (x *JoinAuthMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinAuthMeshReply.ProtoReflect.Descriptor instead.
func (*JoinAuthMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_authentication_proto_rawDescGZIP(), []int{1}
}
func (x *JoinAuthMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
func (x *JoinAuthMeshReply) GetToken() string {
if x != nil && x.Token != nil {
return *x.Token
}
return ""
}
var File_pkg_grpc_ctrlserver_authentication_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_authentication_proto_rawDesc = []byte{
0x0a, 0x28, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61,
0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74,
0x79, 0x70, 0x65, 0x73, 0x22, 0x43, 0x0a, 0x13, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68,
0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d,
0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73,
0x68, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x61, 0x6c, 0x69, 0x61, 0x73, 0x18, 0x02, 0x20, 0x01,
0x28, 0x09, 0x52, 0x05, 0x61, 0x6c, 0x69, 0x61, 0x73, 0x22, 0x52, 0x0a, 0x11, 0x4a, 0x6f, 0x69,
0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18,
0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52,
0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x19, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65,
0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e,
0x88, 0x01, 0x01, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x32, 0x5a, 0x0a,
0x0e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12,
0x48, 0x0a, 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x1d, 0x2e, 0x72, 0x70,
0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65,
0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67,
0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_pkg_grpc_ctrlserver_authentication_proto_rawDescOnce sync.Once
file_pkg_grpc_ctrlserver_authentication_proto_rawDescData = file_pkg_grpc_ctrlserver_authentication_proto_rawDesc
)
func file_pkg_grpc_ctrlserver_authentication_proto_rawDescGZIP() []byte {
file_pkg_grpc_ctrlserver_authentication_proto_rawDescOnce.Do(func() {
file_pkg_grpc_ctrlserver_authentication_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_ctrlserver_authentication_proto_rawDescData)
})
return file_pkg_grpc_ctrlserver_authentication_proto_rawDescData
}
var file_pkg_grpc_ctrlserver_authentication_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_authentication_proto_goTypes = []interface{}{
(*JoinAuthMeshRequest)(nil), // 0: rpctypes.JoinAuthMeshRequest
(*JoinAuthMeshReply)(nil), // 1: rpctypes.JoinAuthMeshReply
}
var file_pkg_grpc_ctrlserver_authentication_proto_depIdxs = []int32{
0, // 0: rpctypes.Authentication.JoinMesh:input_type -> rpctypes.JoinAuthMeshRequest
1, // 1: rpctypes.Authentication.JoinMesh:output_type -> rpctypes.JoinAuthMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_pkg_grpc_ctrlserver_authentication_proto_init() }
func file_pkg_grpc_ctrlserver_authentication_proto_init() {
if File_pkg_grpc_ctrlserver_authentication_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinAuthMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinAuthMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_pkg_grpc_ctrlserver_authentication_proto_msgTypes[1].OneofWrappers = []interface{}{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_authentication_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_pkg_grpc_ctrlserver_authentication_proto_goTypes,
DependencyIndexes: file_pkg_grpc_ctrlserver_authentication_proto_depIdxs,
MessageInfos: file_pkg_grpc_ctrlserver_authentication_proto_msgTypes,
}.Build()
File_pkg_grpc_ctrlserver_authentication_proto = out.File
file_pkg_grpc_ctrlserver_authentication_proto_rawDesc = nil
file_pkg_grpc_ctrlserver_authentication_proto_goTypes = nil
file_pkg_grpc_ctrlserver_authentication_proto_depIdxs = nil
}

View File

@ -1,105 +0,0 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.12
// source: pkg/grpc/ctrlserver/authentication.proto
package rpc
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// AuthenticationClient is the client API for Authentication service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type AuthenticationClient interface {
JoinMesh(ctx context.Context, in *JoinAuthMeshRequest, opts ...grpc.CallOption) (*JoinAuthMeshReply, error)
}
type authenticationClient struct {
cc grpc.ClientConnInterface
}
func NewAuthenticationClient(cc grpc.ClientConnInterface) AuthenticationClient {
return &authenticationClient{cc}
}
func (c *authenticationClient) JoinMesh(ctx context.Context, in *JoinAuthMeshRequest, opts ...grpc.CallOption) (*JoinAuthMeshReply, error) {
out := new(JoinAuthMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.Authentication/JoinMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// AuthenticationServer is the server API for Authentication service.
// All implementations must embed UnimplementedAuthenticationServer
// for forward compatibility
type AuthenticationServer interface {
JoinMesh(context.Context, *JoinAuthMeshRequest) (*JoinAuthMeshReply, error)
mustEmbedUnimplementedAuthenticationServer()
}
// UnimplementedAuthenticationServer must be embedded to have forward compatible implementations.
type UnimplementedAuthenticationServer struct {
}
func (UnimplementedAuthenticationServer) JoinMesh(context.Context, *JoinAuthMeshRequest) (*JoinAuthMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method JoinMesh not implemented")
}
func (UnimplementedAuthenticationServer) mustEmbedUnimplementedAuthenticationServer() {}
// UnsafeAuthenticationServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to AuthenticationServer will
// result in compilation errors.
type UnsafeAuthenticationServer interface {
mustEmbedUnimplementedAuthenticationServer()
}
func RegisterAuthenticationServer(s grpc.ServiceRegistrar, srv AuthenticationServer) {
s.RegisterService(&Authentication_ServiceDesc, srv)
}
func _Authentication_JoinMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(JoinAuthMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(AuthenticationServer).JoinMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.Authentication/JoinMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AuthenticationServer).JoinMesh(ctx, req.(*JoinAuthMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// Authentication_ServiceDesc is the grpc.ServiceDesc for Authentication service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var Authentication_ServiceDesc = grpc.ServiceDesc{
ServiceName: "rpctypes.Authentication",
HandlerType: (*AuthenticationServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "JoinMesh",
Handler: _Authentication_JoinMesh_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver/authentication.proto",
}

View File

@ -1,147 +1,265 @@
package sync package sync
import ( import (
"fmt"
"io" "io"
"math/rand" "math/rand"
"sync"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
// Syncer: picks random nodes from the meshs // Syncer: picks random nodes from the meshs
type Syncer interface { type Syncer interface {
Sync(meshId string) error Sync(theMesh mesh.MeshProvider) (bool, error)
SyncMeshes() error SyncMeshes() error
} }
// SyncerImpl: implementation of a syncer to sync meshes
type SyncerImpl struct { type SyncerImpl struct {
manager mesh.MeshManager meshManager mesh.MeshManager
requester SyncRequester requester SyncRequester
infectionCount int infectionCount int
syncCount int syncCount int
cluster conn.ConnCluster cluster conn.ConnCluster
conf *conf.DaemonConfiguration configuration *conf.DaemonConfiguration
lastSync uint64 lastSync map[string]int64
lastPoll map[string]int64
lastSyncLock sync.RWMutex
lastPollLock sync.RWMutex
} }
// Sync: Sync random nodes // Sync: Sync with random nodes. Returns true if there was changes false otherwise
func (s *SyncerImpl) Sync(meshId string) error { func (s *SyncerImpl) Sync(correspondingMesh mesh.MeshProvider) (bool, error) {
// Self can be nil if the node is removed if correspondingMesh == nil {
self, _ := s.manager.GetSelf(meshId) return false, fmt.Errorf("mesh provided was nil cannot sync nil mesh")
}
correspondingMesh := s.manager.GetMesh(meshId) // Self can be nil if the node is removed
selfID := s.meshManager.GetPublicKey()
self, _ := correspondingMesh.GetNode(selfID.String())
correspondingMesh.Prune() correspondingMesh.Prune()
if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 { if correspondingMesh.HasChanges() {
logging.Log.WriteInfof("No changes for %s", meshId) logging.Log.WriteInfof("meshes %s has changes", correspondingMesh.GetMeshId())
return nil }
// If removed sync with other nodes to gossip the node is removed
if self != nil && self.GetType() == conf.PEER_ROLE && !correspondingMesh.HasChanges() && s.infectionCount == 0 {
logging.Log.WriteInfof("no changes for %s", correspondingMesh.GetMeshId())
// If not synchronised in certain time pull from random neighbour
if s.configuration.PullInterval != 0 && time.Now().Unix()-s.lastSync[correspondingMesh.GetMeshId()] > int64(s.configuration.PullInterval) {
return s.Pull(self, correspondingMesh)
}
return false, nil
} }
before := time.Now() before := time.Now()
s.manager.GetRouteManager().UpdateRoutes()
publicKey := s.manager.GetPublicKey()
logging.Log.WriteInfof(publicKey.String())
publicKey := s.meshManager.GetPublicKey()
nodeNames := correspondingMesh.GetPeers() nodeNames := correspondingMesh.GetPeers()
if self != nil {
nodeNames = lib.Filter(nodeNames, func(s string) bool { nodeNames = lib.Filter(nodeNames, func(s string) bool {
return s != mesh.NodeID(self) // Filter our only public key out so we dont sync with ourself
return s != publicKey.String()
}) })
}
var gossipNodes []string var gossipNodes []string
// Clients always pings its peer for configuration // Clients always pings its peer for configuration
if self != nil && self.GetType() == conf.CLIENT_ROLE { if self != nil && self.GetType() == conf.CLIENT_ROLE && len(nodeNames) > 1 {
keyFunc := lib.HashString neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
bucketFunc := lib.HashString
neighbour := lib.ConsistentHash(nodeNames, publicKey.String(), keyFunc, bucketFunc) if len(neighbours) == 0 {
gossipNodes = make([]string, 1) return false, nil
gossipNodes[0] = neighbour }
// Peer with 2 nodes so that there is redundancy in
// the situation the node leaves pre-emptively
redundancyLength := min(len(neighbours), 2)
gossipNodes = neighbours[:redundancyLength]
} else { } else {
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String()) neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
gossipNodes = lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate) gossipNodes = lib.RandomSubsetOfLength(neighbours, s.configuration.Branch)
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance { if len(nodeNames) > s.configuration.ClusterSize && rand.Float64() < s.configuration.InterClusterChance {
gossipNodes[len(gossipNodes)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String()) gossipNodes[len(gossipNodes)-1] = s.cluster.GetInterCluster(nodeNames, publicKey.String())
} }
} }
var succeeded bool = false var succeeded bool = false
// Do this synchronously to conserve bandwidth
for _, node := range gossipNodes { for _, node := range gossipNodes {
correspondingPeer := s.manager.GetNode(meshId, node) correspondingPeer, err := correspondingMesh.GetNode(node)
if correspondingPeer == nil { if correspondingPeer == nil || err != nil {
logging.Log.WriteErrorf("node %s does not exist", node) logging.Log.WriteErrorf("node %s does not exist", node)
continue
} }
err := s.requester.SyncMesh(meshId, correspondingPeer) err = s.requester.SyncMesh(correspondingMesh, correspondingPeer)
if err == nil || err == io.EOF { if err == nil || err == io.EOF {
succeeded = true succeeded = true
} else {
// If the synchronisation operation has failed them mark a gravestone
// preventing the peer from being re-contacted until it has updated
// itself
s.manager.GetMesh(meshId).Mark(node)
} }
}
s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
if !succeeded {
// If could not gossip with anyone then repeat.
s.infectionCount++
}
s.manager.GetMesh(meshId).SaveChanges()
s.lastSync = uint64(time.Now().Unix())
logging.Log.WriteInfof("UPDATING WG CONF")
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
return nil
}
// SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error {
for meshId := range s.manager.GetMeshes() {
err := s.Sync(meshId)
if err != nil { if err != nil {
logging.Log.WriteErrorf(err.Error()) logging.Log.WriteErrorf(err.Error())
} }
} }
s.syncCount++
logging.Log.WriteInfof("sync time: %v", time.Since(before))
logging.Log.WriteInfof("number of syncs: %d", s.syncCount)
s.infectionCount = ((s.configuration.InfectionCount + s.infectionCount - 1) % s.configuration.InfectionCount)
if !succeeded {
s.infectionCount++
}
correspondingMesh.SaveChanges()
s.lastSyncLock.Lock()
s.lastSync[correspondingMesh.GetMeshId()] = time.Now().Unix()
s.lastSyncLock.Unlock()
return true, nil
}
// Pull one node in the cluster, if there has not been message dissemination
// in a certain period of time pull a random node within the cluster
func (s *SyncerImpl) Pull(self mesh.MeshNode, mesh mesh.MeshProvider) (bool, error) {
peers := mesh.GetPeers()
pubKey, _ := self.GetPublicKey()
neighbours := s.cluster.GetNeighbours(peers, pubKey.String())
neighbour := lib.RandomSubsetOfLength(neighbours, 1)
if len(neighbour) == 0 {
logging.Log.WriteInfof("no neighbours")
return false, nil
}
logging.Log.WriteInfof("pulling from node %s", neighbour[0])
pullNode, err := mesh.GetNode(neighbour[0])
if err != nil || pullNode == nil {
return false, fmt.Errorf("node %s does not exist in the mesh", neighbour[0])
}
err = s.requester.SyncMesh(mesh, pullNode)
if err == nil || err == io.EOF {
s.lastSync[mesh.GetMeshId()] = time.Now().Unix()
} else {
return false, err
}
s.syncCount++
changes := mesh.HasChanges()
return changes, nil
}
// SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error {
var wg sync.WaitGroup
meshes := s.meshManager.GetMeshes()
s.lastPollLock.Lock()
meshesToSync := lib.Filter(lib.MapValues(meshes), func(mesh mesh.MeshProvider) bool {
return time.Now().Unix()-s.lastPoll[mesh.GetMeshId()] >= int64(s.configuration.SyncInterval)
})
s.lastPollLock.Unlock()
changes := make(chan bool, len(meshesToSync))
for i := 0; i < len(meshesToSync); {
wg.Add(1)
sync := func(index int) {
defer wg.Done()
var hasChanges bool = false
mesh := meshesToSync[index]
hasChanges, err := s.Sync(mesh)
changes <- hasChanges
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
s.lastPollLock.Lock()
s.lastPoll[mesh.GetMeshId()] = time.Now().Unix()
s.lastPollLock.Unlock()
}
go sync(i)
i++
}
wg.Wait()
hasChanges := false
for i := 0; i < len(changes); i++ {
if <-changes {
hasChanges = true
}
}
var err error
err = s.meshManager.GetRouteManager().UpdateRoutes()
if err != nil {
logging.Log.WriteErrorf("update routes failed %s", err.Error())
}
if hasChanges {
logging.Log.WriteInfof("updating the WireGuard configuration")
err = s.meshManager.ApplyConfig()
if err != nil {
logging.Log.WriteErrorf("failed to update config %s", err.Error())
}
}
return nil return nil
} }
func NewSyncer(m mesh.MeshManager, conf *conf.DaemonConfiguration, r SyncRequester) Syncer { type NewSyncerParams struct {
cluster, _ := conn.NewConnCluster(conf.ClusterSize) MeshManager mesh.MeshManager
ConnectionManager conn.ConnectionManager
Configuration *conf.DaemonConfiguration
Requester SyncRequester
}
func NewSyncer(params *NewSyncerParams) Syncer {
cluster, _ := conn.NewConnCluster(params.Configuration.ClusterSize)
syncRequester := NewSyncRequester(NewSyncRequesterParams{
MeshManager: params.MeshManager,
ConnectionManager: params.ConnectionManager,
Configuration: params.Configuration,
})
return &SyncerImpl{ return &SyncerImpl{
manager: m, meshManager: params.MeshManager,
conf: conf, configuration: params.Configuration,
requester: r, requester: syncRequester,
infectionCount: 0, infectionCount: 0,
syncCount: 0, syncCount: 0,
cluster: cluster} cluster: cluster,
lastSync: make(map[string]int64),
lastPoll: make(map[string]int64)}
} }

View File

@ -1,41 +1,60 @@
package sync package sync
import ( import (
logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
// SyncErrorHandler: Handles errors when attempting to sync // SyncErrorHandler: Handles errors when attempting to sync
type SyncErrorHandler interface { type SyncErrorHandler interface {
Handle(meshId string, endpoint string, err error) bool Handle(mesh mesh.MeshProvider, endpoint string, err error) bool
} }
// SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler // SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler
type SyncErrorHandlerImpl struct { type SyncErrorHandlerImpl struct {
meshManager mesh.MeshManager meshManager mesh.MeshManager
connManager conn.ConnectionManager
} }
func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool { func (s *SyncErrorHandlerImpl) handleFailed(mesh mesh.MeshProvider, nodeId string) bool {
mesh := s.meshManager.GetMesh(meshId)
mesh.Mark(nodeId) mesh.Mark(nodeId)
node, err := mesh.GetNode(nodeId)
if err != nil {
s.connManager.RemoveConnection(node.GetHostEndpoint())
}
return true return true
} }
func (s *SyncErrorHandlerImpl) Handle(meshId string, nodeId string, err error) bool { func (s *SyncErrorHandlerImpl) handleDeadlineExceeded(mesh mesh.MeshProvider, nodeId string) bool {
node, err := mesh.GetNode(nodeId)
if err != nil {
return false
}
s.connManager.RemoveConnection(node.GetHostEndpoint())
return true
}
func (s *SyncErrorHandlerImpl) Handle(mesh mesh.MeshProvider, nodeId string, err error) bool {
errStatus, _ := status.FromError(err) errStatus, _ := status.FromError(err)
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message()) logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
switch errStatus.Code() { switch errStatus.Code() {
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound: case codes.Unavailable, codes.Unknown, codes.Internal, codes.NotFound:
return s.handleFailed(meshId, nodeId) return s.handleFailed(mesh, nodeId)
case codes.DeadlineExceeded:
return s.handleDeadlineExceeded(mesh, nodeId)
} }
return false return false
} }
func NewSyncErrorHandler(m mesh.MeshManager) SyncErrorHandler { func NewSyncErrorHandler(m mesh.MeshManager, conn conn.ConnectionManager) SyncErrorHandler {
return &SyncErrorHandlerImpl{meshManager: m} return &SyncErrorHandlerImpl{meshManager: m, connManager: conn}
} }

View File

@ -2,76 +2,44 @@ package sync
import ( import (
"context" "context"
"errors"
"io" "io"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/conf"
logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/smegmesh/pkg/rpc"
) )
// SyncRequester: coordinates the syncing of meshes // SyncRequester: coordinates the syncing of meshes
type SyncRequester interface { type SyncRequester interface {
GetMesh(meshId string, ifName string, port int, endPoint string) error SyncMesh(mesh mesh.MeshProvider, meshNode mesh.MeshNode) error
SyncMesh(meshid string, meshNode mesh.MeshNode) error
} }
type SyncRequesterImpl struct { type SyncRequesterImpl struct {
server *ctrlserver.MeshCtrlServer manager mesh.MeshManager
connectionManager conn.ConnectionManager
configuration *conf.DaemonConfiguration
errorHdlr SyncErrorHandler errorHdlr SyncErrorHandler
} }
// GetMesh: Retrieves the local state of the mesh at the endpoint // handleErr: handleGrpc errors
func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endPoint string) error { func (s *SyncRequesterImpl) handleErr(mesh mesh.MeshProvider, pubKey string, err error) error {
peerConnection, err := s.server.ConnectionManager.GetConnection(endPoint) ok := s.errorHdlr.Handle(mesh, pubKey, err)
if err != nil {
return err
}
client, err := peerConnection.GetClient()
if err != nil {
return err
}
c := rpc.NewSyncServiceClient(client)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
reply, err := c.GetConf(ctx, &rpc.GetConfRequest{MeshId: meshId})
if err != nil {
return err
}
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
MeshId: meshId,
WgPort: port,
MeshBytes: reply.Mesh,
})
return err
}
func (s *SyncRequesterImpl) handleErr(meshId, pubKey string, err error) error {
ok := s.errorHdlr.Handle(meshId, pubKey, err)
if ok { if ok {
return nil return nil
} }
return err return err
} }
// SyncMesh: Proactively send a sync request to the other mesh // SyncMesh: Proactively send a sync request to the other mesh
func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) error { func (s *SyncRequesterImpl) SyncMesh(mesh mesh.MeshProvider, meshNode mesh.MeshNode) error {
endpoint := meshNode.GetHostEndpoint() endpoint := meshNode.GetHostEndpoint()
pubKey, _ := meshNode.GetPublicKey() pubKey, _ := meshNode.GetPublicKey()
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint) peerConnection, err := s.connectionManager.GetConnection(endpoint)
if err != nil { if err != nil {
return err return err
@ -83,15 +51,9 @@ func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) erro
return err return err
} }
mesh := s.server.MeshManager.GetMesh(meshId)
if mesh == nil {
return errors.New("mesh does not exist")
}
c := rpc.NewSyncServiceClient(client) c := rpc.NewSyncServiceClient(client)
syncTimeOut := float64(s.server.Conf.SyncRate) * float64(time.Second) syncTimeOut := float64(s.configuration.SyncInterval) * float64(time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut)) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut))
defer cancel() defer cancel()
@ -99,11 +61,11 @@ func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) erro
err = s.syncMesh(mesh, ctx, c) err = s.syncMesh(mesh, ctx, c)
if err != nil { if err != nil {
return s.handleErr(meshId, pubKey.String(), err) s.handleErr(mesh, pubKey.String(), err)
} }
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId) logging.Log.WriteInfof("synced with node: %s meshId: %s\n", endpoint, mesh.GetMeshId())
return nil return err
} }
func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error { func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
@ -127,7 +89,7 @@ func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context
in, err := stream.Recv() in, err := stream.Recv()
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
logging.Log.WriteInfof("Stream recv error: %s\n", err.Error()) logging.Log.WriteInfof("stream recv error: %s\n", err.Error())
return err return err
} }
@ -136,7 +98,7 @@ func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context
} }
if err != nil { if err != nil {
logging.Log.WriteInfof("Syncer recv error: %s\n", err.Error()) logging.Log.WriteInfof("syncer recv error: %s\n", err.Error())
return err return err
} }
@ -150,7 +112,17 @@ func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context
return nil return nil
} }
func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester { type NewSyncRequesterParams struct {
errorHdlr := NewSyncErrorHandler(s.MeshManager) MeshManager mesh.MeshManager
return &SyncRequesterImpl{server: s, errorHdlr: errorHdlr} ConnectionManager conn.ConnectionManager
Configuration *conf.DaemonConfiguration
}
func NewSyncRequester(params NewSyncRequesterParams) SyncRequester {
errorHdlr := NewSyncErrorHandler(params.MeshManager, params.ConnectionManager)
return &SyncRequesterImpl{manager: params.MeshManager,
connectionManager: params.ConnectionManager,
configuration: params.Configuration,
errorHdlr: errorHdlr,
}
} }

View File

@ -1,18 +0,0 @@
package sync
import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// Run implements SyncScheduler.
func syncFunction(syncer Syncer) lib.TimerFunc {
return func() error {
syncer.SyncMeshes()
return nil
}
}
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer {
return lib.NewTimer(syncFunction(syncer), s.Conf.SyncRate)
}

View File

@ -6,19 +6,18 @@ import (
"errors" "errors"
"io" "io"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/rpc"
"github.com/tim-beatham/wgmesh/pkg/rpc"
) )
type SyncServiceImpl struct { type SyncServiceImpl struct {
rpc.UnimplementedSyncServiceServer rpc.UnimplementedSyncServiceServer
Server *ctrlserver.MeshCtrlServer MeshManager mesh.MeshManager
} }
// GetMesh: Gets a nodes local mesh configuration as a CRDT // GetMesh: Gets a nodes local mesh configuration as a CRDT
func (s *SyncServiceImpl) GetConf(context context.Context, request *rpc.GetConfRequest) (*rpc.GetConfReply, error) { func (s *SyncServiceImpl) GetConf(context context.Context, request *rpc.GetConfRequest) (*rpc.GetConfReply, error) {
mesh := s.Server.MeshManager.GetMesh(request.MeshId) mesh := s.MeshManager.GetMesh(request.MeshId)
if mesh == nil { if mesh == nil {
return nil, errors.New("mesh does not exist") return nil, errors.New("mesh does not exist")
@ -56,7 +55,7 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
if len(meshId) == 0 { if len(meshId) == 0 {
meshId = in.MeshId meshId = in.MeshId
mesh := s.Server.MeshManager.GetMesh(meshId) mesh := s.MeshManager.GetMesh(meshId)
if mesh == nil { if mesh == nil {
return errors.New("mesh does not exist") return errors.New("mesh does not exist")
@ -92,7 +91,3 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
} }
} }
} }
func NewSyncService(server *ctrlserver.MeshCtrlServer) *SyncServiceImpl {
return &SyncServiceImpl{Server: server}
}

View File

@ -1,15 +0,0 @@
package timer
import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error {
logging.Log.WriteInfof("Updated Timestamp")
return ctrlServer.MeshManager.UpdateTimeStamp()
}
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
}

View File

@ -1,15 +1,20 @@
package wg package wg
import "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
type WgInterfaceManipulatorStub struct{} type WgInterfaceManipulatorStub struct{}
func (i *WgInterfaceManipulatorStub) CreateInterface(port int) (string, error) { // CreateInterface creates a WireGuard interface
return "", nil func (w *WgInterfaceManipulatorStub) CreateInterface(port int, privateKey *wgtypes.Key) (string, error) {
return "aninterface", nil
} }
func (i *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error { // AddAddress adds an address to the given interface name
func (w *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error {
return nil return nil
} }
func (i *WgInterfaceManipulatorStub) RemoveInterface(ifName string) error { // RemoveInterface removes the specified interface
func (w *WgInterfaceManipulatorStub) RemoveInterface(ifName string) error {
return nil return nil
} }

View File

@ -2,14 +2,6 @@ package wg
import "golang.zx2c4.com/wireguard/wgctrl/wgtypes" import "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
type WgError struct {
msg string
}
func (m *WgError) Error() string {
return m.msg
}
type WgInterfaceManipulator interface { type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface // CreateInterface creates a WireGuard interface
CreateInterface(port int, privateKey *wgtypes.Key) (string, error) CreateInterface(port int, privateKey *wgtypes.Key) (string, error)
@ -18,3 +10,11 @@ type WgInterfaceManipulator interface {
// RemoveInterface removes the specified interface // RemoveInterface removes the specified interface
RemoveInterface(ifName string) error RemoveInterface(ifName string) error
} }
type WgError struct {
msg string
}
func (m *WgError) Error() string {
return m.msg
}

View File

@ -5,8 +5,8 @@ import (
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )

@ -1 +0,0 @@
Subproject commit c1128bcd98a6ce4a04d4fe55c210d115d564419a