feature: add update setup key endpoint

This commit is contained in:
braginini 2021-08-20 22:33:43 +02:00
parent 617f79e2e0
commit 2e9fc20567
12 changed files with 315 additions and 32 deletions

1
go.mod
View File

@ -7,6 +7,7 @@ require (
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/golang/protobuf v1.5.2
github.com/google/uuid v1.2.0
github.com/gorilla/mux v1.8.0
github.com/kardianos/service v1.2.0
github.com/onsi/ginkgo v1.16.4
github.com/onsi/gomega v1.13.0

2
go.sum
View File

@ -100,6 +100,8 @@ github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk=

View File

@ -8,6 +8,7 @@ import (
"net"
"strings"
"sync"
"time"
)
type AccountManager struct {
@ -77,6 +78,79 @@ func (manager *AccountManager) GetPeersForAPeer(peerKey string) ([]*Peer, error)
return res, nil
}
//AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account
func (manager *AccountManager) AddSetupKey(accountId string, keyName string, keyType SetupKeyType, expiresIn time.Duration) (*SetupKey, error) {
manager.mux.Lock()
defer manager.mux.Unlock()
account, err := manager.Store.GetAccount(accountId)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
setupKey := GenerateSetupKey(keyName, keyType, expiresIn)
account.SetupKeys[setupKey.Key] = setupKey
err = manager.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key")
}
return setupKey, nil
}
//RevokeSetupKey marks SetupKey as revoked - becomes not valid anymore
func (manager *AccountManager) RevokeSetupKey(accountId string, keyId string) (*SetupKey, error) {
manager.mux.Lock()
defer manager.mux.Unlock()
account, err := manager.Store.GetAccount(accountId)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
setupKey := getAccountSetupKeyById(account, keyId)
if setupKey == nil {
return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", keyId)
}
keyCopy := setupKey.Copy()
keyCopy.Revoked = true
account.SetupKeys[keyCopy.Key] = keyCopy
err = manager.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key")
}
return keyCopy, nil
}
//RenameSetupKey renames existing setup key of the specified account.
func (manager *AccountManager) RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error) {
manager.mux.Lock()
defer manager.mux.Unlock()
account, err := manager.Store.GetAccount(accountId)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
setupKey := getAccountSetupKeyById(account, keyId)
if setupKey == nil {
return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", keyId)
}
keyCopy := setupKey.Copy()
keyCopy.Name = newName
account.SetupKeys[keyCopy.Key] = keyCopy
err = manager.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key")
}
return keyCopy, nil
}
//GetAccount returns an existing account or error (NotFound) if doesn't exist
func (manager *AccountManager) GetAccount(accountId string) (*Account, error) {
manager.mux.Lock()
@ -177,13 +251,7 @@ func (manager *AccountManager) AddPeer(setupKey string, peerKey string) (*Peer,
return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", upperKey)
}
for _, key := range account.SetupKeys {
if upperKey == key.Key {
sk = key
break
}
}
sk = getAccountSetupKeyByKey(account, setupKey)
if sk == nil {
// shouldn't happen actually
return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", upperKey)
@ -242,3 +310,21 @@ func newAccount() (*Account, *SetupKey) {
accountId := uuid.New().String()
return newAccountWithId(accountId)
}
func getAccountSetupKeyById(acc *Account, keyId string) *SetupKey {
for _, k := range acc.SetupKeys {
if keyId == k.Id {
return k
}
}
return nil
}
func getAccountSetupKeyByKey(acc *Account, key string) *SetupKey {
for _, k := range acc.SetupKeys {
if key == k.Key {
return k
}
}
return nil
}

View File

@ -29,7 +29,7 @@ func TestAccountManager_AddAccount(t *testing.T) {
}
if account.Id != expectedId {
t.Errorf("expected account to have ID = %s, got %s", expectedId, account.Id)
t.Errorf("expected account to have Id = %s, got %s", expectedId, account.Id)
}
if len(account.Peers) != expectedPeersSize {
@ -130,7 +130,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
}
if account.Id != getAccount.Id {
t.Errorf("expected account.ID %s, got %s", account.Id, getAccount.Id)
t.Errorf("expected account.Id %s, got %s", account.Id, getAccount.Id)
}
for _, peer := range account.Peers {

View File

@ -28,7 +28,7 @@ func NewPeers(accountManager *server.AccountManager) *Peers {
}
}
func (h *Peers) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
accountId := extractAccountIdFromRequestContext(r)

View File

@ -2,8 +2,11 @@ package handler
import (
"encoding/json"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/management/server"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http"
"time"
)
@ -15,11 +18,21 @@ type SetupKeys struct {
// SetupKeyResponse is a response sent to the client
type SetupKeyResponse struct {
Id string
Key string
Name string
Expires time.Time
Type string
Type server.SetupKeyType
Valid bool
Revoked bool
}
// SetupKeyRequest is a request sent by client. This object contains fields that can be modified
type SetupKeyRequest struct {
Name string
Type server.SetupKeyType
ExpiresIn Duration
Revoked bool
}
func NewSetupKeysHandler(accountManager *server.AccountManager) *SetupKeys {
@ -28,7 +41,90 @@ func NewSetupKeysHandler(accountManager *server.AccountManager) *SetupKeys {
}
}
func (h *SetupKeys) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (h *SetupKeys) CreateKey(w http.ResponseWriter, r *http.Request) {
accountId := extractAccountIdFromRequestContext(r)
req := &SetupKeyRequest{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
setupKey, err := h.accountManager.AddSetupKey(accountId, req.Name, req.Type, req.ExpiresIn.Duration)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, "account not found", http.StatusNotFound)
return
}
http.Error(w, "failed adding setup key", http.StatusInternalServerError)
return
}
writeSuccess(w, setupKey)
}
func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
accountId := extractAccountIdFromRequestContext(r)
vars := mux.Vars(r)
keyId := vars["id"]
if len(keyId) == 0 {
http.Error(w, "invalid key Id", http.StatusBadRequest)
return
}
switch r.Method {
case http.MethodPost:
req := &SetupKeyRequest{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var key *server.SetupKey
if req.Revoked {
//handle only if being revoked, don't allow to enable key again for now
key, err = h.accountManager.RevokeSetupKey(accountId, keyId)
if err != nil {
http.Error(w, "failed revoking key", http.StatusInternalServerError)
return
}
}
if len(req.Name) != 0 {
key, err = h.accountManager.RenameSetupKey(accountId, keyId, req.Name)
if err != nil {
http.Error(w, "failed renaming key", http.StatusInternalServerError)
return
}
}
if key != nil {
writeSuccess(w, key)
}
return
case http.MethodGet:
account, err := h.accountManager.GetAccount(accountId)
if err != nil {
http.Error(w, "account doesn't exist", http.StatusInternalServerError)
return
}
for _, key := range account.SetupKeys {
if key.Id == keyId {
writeSuccess(w, key)
return
}
}
http.Error(w, "setup key not found", http.StatusNotFound)
return
default:
http.Error(w, "", http.StatusNotFound)
}
}
func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
accountId := extractAccountIdFromRequestContext(r)
@ -44,13 +140,7 @@ func (h *SetupKeys) ServeHTTP(w http.ResponseWriter, r *http.Request) {
respBody := []*SetupKeyResponse{}
for _, key := range account.SetupKeys {
respBody = append(respBody, &SetupKeyResponse{
Key: key.Key,
Name: key.Name,
Expires: key.ExpiresAt,
Type: string(key.Type),
Valid: key.IsValid(),
})
respBody = append(respBody, toResponseBody(key))
}
err = json.NewEncoder(w).Encode(respBody)
@ -63,3 +153,25 @@ func (h *SetupKeys) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", http.StatusNotFound)
}
}
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
w.WriteHeader(200)
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(toResponseBody(key))
if err != nil {
http.Error(w, "failed handling request", http.StatusInternalServerError)
return
}
}
func toResponseBody(key *server.SetupKey) *SetupKeyResponse {
return &SetupKeyResponse{
Id: key.Id,
Key: key.Key,
Name: key.Name,
Expires: key.ExpiresAt,
Type: key.Type,
Valid: key.IsValid(),
Revoked: key.Revoked,
}
}

View File

@ -1,8 +1,11 @@
package handler
import (
"encoding/json"
"errors"
"github.com/golang-jwt/jwt"
"net/http"
"time"
)
// extractAccountIdFromRequestContext extracts accountId from the request context previously filled by the JWT token (after auth)
@ -13,3 +16,33 @@ func extractAccountIdFromRequestContext(r *http.Request) string {
//actually a user id but for now we have a 1 to 1 mapping.
return claims["sub"].(string)
}
//Duration is used strictly for JSON requests/responses due to duration marshalling issues
type Duration struct {
time.Duration
}
func (d Duration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
func (d *Duration) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch value := v.(type) {
case float64:
d.Duration = time.Duration(value)
return nil
case string:
var err error
d.Duration, err = time.ParseDuration(value)
if err != nil {
return err
}
return nil
default:
return errors.New("invalid duration")
}
}

View File

@ -2,6 +2,7 @@ package http
import (
"context"
"github.com/gorilla/mux"
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
s "github.com/wiretrustee/wiretrustee/management/server"
@ -54,20 +55,24 @@ func (s *Server) Start() error {
}
corsMiddleware := cors.AllowAll()
h := http.NewServeMux()
s.server.Handler = h
r := mux.NewRouter()
r.Use(jwtMiddleware.Handler, corsMiddleware.Handler)
peersHandler := handler.NewPeers(s.accountManager)
keysHandler := handler.NewSetupKeysHandler(s.accountManager)
h.Handle("/api/peers", corsMiddleware.Handler(jwtMiddleware.Handler(peersHandler)))
h.Handle("/api/setup-keys", corsMiddleware.Handler(jwtMiddleware.Handler(keysHandler)))
http.Handle("/", h)
r.HandleFunc("/api/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
r.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("GET", "OPTIONS")
r.HandleFunc("/api/setup-keys", keysHandler.CreateKey).Methods("PUT")
r.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).Methods("GET", "POST", "OPTIONS")
http.Handle("/", r)
if s.certManager != nil {
// if HTTPS is enabled we reuse the listener from the cert manager
listener := s.certManager.Listener()
log.Infof("http server listening on %s", listener.Addr())
if err = http.Serve(listener, s.certManager.HTTPHandler(h)); err != nil {
if err = http.Serve(listener, s.certManager.HTTPHandler(r)); err != nil {
log.Errorf("failed to serve https server: %v", err)
return err
}

View File

@ -2,6 +2,8 @@ package server
import (
"github.com/google/uuid"
"hash/fnv"
"strconv"
"strings"
"time"
)
@ -23,6 +25,7 @@ type SetupKeyType string
// SetupKey represents a pre-authorized key used to register machines (peers)
type SetupKey struct {
Id string
Key string
Name string
Type SetupKeyType
@ -34,6 +37,20 @@ type SetupKey struct {
UsedTimes int
}
//Copy copies SetupKey to a new object
func (key *SetupKey) Copy() *SetupKey {
return &SetupKey{
Id: key.Id,
Key: key.Key,
Name: key.Name,
Type: key.Type,
CreatedAt: key.CreatedAt,
ExpiresAt: key.ExpiresAt,
Revoked: key.Revoked,
UsedTimes: key.UsedTimes,
}
}
// IsValid is true if the key was not revoked, is not expired and used not more than it was supposed to
func (key *SetupKey) IsValid() bool {
expired := time.Now().After(key.ExpiresAt)
@ -43,9 +60,11 @@ func (key *SetupKey) IsValid() bool {
// GenerateSetupKey generates a new setup key
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration) *SetupKey {
key := strings.ToUpper(uuid.New().String())
createdAt := time.Now()
return &SetupKey{
Key: strings.ToUpper(uuid.New().String()),
Id: strconv.Itoa(int(Hash(key))),
Key: key,
Name: name,
Type: t,
CreatedAt: createdAt,
@ -59,3 +78,12 @@ func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration) *Setu
func GenerateDefaultSetupKey() *SetupKey {
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration)
}
func Hash(s string) uint32 {
h := fnv.New32a()
_, err := h.Write([]byte(s))
if err != nil {
panic(err)
}
return h.Sum32()
}

View File

@ -2,6 +2,7 @@ package server
import (
"github.com/google/uuid"
"strconv"
"testing"
"time"
)
@ -16,7 +17,8 @@ func TestGenerateDefaultSetupKey(t *testing.T) {
key := GenerateDefaultSetupKey()
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, expectedExpiresAt)
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))))
}
@ -30,7 +32,7 @@ func TestGenerateSetupKey(t *testing.T) {
key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour)
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, expectedExpiresAt)
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))))
}
@ -68,7 +70,8 @@ func TestSetupKey_IsValid(t *testing.T) {
}
}
func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string, expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time) {
func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string,
expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string) {
if key.Name != expectedName {
t.Errorf("expected setup key to have Name %v, got %v", expectedName, key.Name)
}
@ -97,4 +100,17 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke
if err != nil {
t.Errorf("expected key to be a valid UUID, got %v, %v", key.Key, err)
}
if key.Id != strconv.Itoa(int(Hash(key.Key))) {
t.Errorf("expected key Id t= %v, got %v", expectedID, key.Id)
}
}
func TestSetupKey_Copy(t *testing.T) {
key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour)
keyCopy := key.Copy()
assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id)
}

View File

@ -109,7 +109,7 @@ var _ = Describe("Client", func() {
})
})
Context("with a raw client and no ID header", func() {
Context("with a raw client and no Id header", func() {
It("should fail", func() {
client := createRawSignalClient(addr)
@ -125,7 +125,7 @@ var _ = Describe("Client", func() {
})
})
Context("with a raw client and with an ID header", func() {
Context("with a raw client and with an Id header", func() {
It("should be successful", func() {
md := metadata.New(map[string]string{sigProto.HeaderId: "peer"})

View File

@ -88,7 +88,7 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
}
// Handles initial Peer connection.
// Each connection must provide an ID header.
// Each connection must provide an Id header.
// At this moment the connecting Peer will be registered in the peer.Registry
func (s Server) connectPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
if meta, hasMeta := metadata.FromIncomingContext(stream.Context()); hasMeta {