send/recv: job-level bandwidth limiting

Sponsored-by: Prominic.NET, Inc.

fixes #339
This commit is contained in:
Christian Schwarz 2021-07-09 16:30:44 +02:00
parent 5b16769057
commit f5f269bfd5
13 changed files with 427 additions and 9 deletions

View File

@ -13,6 +13,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/zrepl/yaml-config" "github.com/zrepl/yaml-config"
"github.com/zrepl/zrepl/util/datasizeunit"
zfsprop "github.com/zrepl/zrepl/zfs/property" zfsprop "github.com/zrepl/zrepl/zfs/property"
) )
@ -87,6 +88,8 @@ type SendOptions struct {
Compressed bool `yaml:"compressed,optional,default=false"` Compressed bool `yaml:"compressed,optional,default=false"`
EmbeddedData bool `yaml:"embbeded_data,optional,default=false"` EmbeddedData bool `yaml:"embbeded_data,optional,default=false"`
Saved bool `yaml:"saved,optional,default=false"` Saved bool `yaml:"saved,optional,default=false"`
BandwidthLimit *BandwidthLimit `yaml:"bandwidth_limit,optional,fromdefaults"`
} }
type RecvOptions struct { type RecvOptions struct {
@ -96,6 +99,15 @@ type RecvOptions struct {
// Reencrypt bool `yaml:"reencrypt"` // Reencrypt bool `yaml:"reencrypt"`
Properties *PropertyRecvOptions `yaml:"properties,fromdefaults"` Properties *PropertyRecvOptions `yaml:"properties,fromdefaults"`
BandwidthLimit *BandwidthLimit `yaml:"bandwidth_limit,optional,fromdefaults"`
}
var _ yaml.Unmarshaler = &datasizeunit.Bits{}
type BandwidthLimit struct {
Max datasizeunit.Bits `yaml:"max,default=-1 B"`
BucketCapacity datasizeunit.Bits `yaml:"bucket_capacity,default=128 KiB"`
} }
type Replication struct { type Replication struct {
@ -113,10 +125,6 @@ type ReplicationOptionsConcurrency struct {
SizeEstimates int `yaml:"size_estimates,optional,default=4"` SizeEstimates int `yaml:"size_estimates,optional,default=4"`
} }
func (l *RecvOptions) SetDefault() {
*l = RecvOptions{Properties: &PropertyRecvOptions{}}
}
type PropertyRecvOptions struct { type PropertyRecvOptions struct {
Inherit []zfsprop.Property `yaml:"inherit,optional"` Inherit []zfsprop.Property `yaml:"inherit,optional"`
Override map[zfsprop.Property]string `yaml:"override,optional"` Override map[zfsprop.Property]string `yaml:"override,optional"`

View File

@ -0,0 +1,41 @@
jobs:
- type: sink
name: "limited_sink"
root_fs: "fs0"
recv:
bandwidth_limit:
max: 12345 B
serve:
type: local
listener_name: localsink
- type: push
name: "limited_push"
connect:
type: local
listener_name: localsink
client_identity: local_backup
filesystems: {
"root<": true,
}
send:
bandwidth_limit:
max: 54321 B
bucket_capacity: 1024 B
snapshotting:
type: manual
pruning:
keep_sender:
- type: last_n
count: 1
keep_receiver:
- type: last_n
count: 1
- type: sink
name: "nolimit_sink"
root_fs: "fs1"
serve:
type: local
listener_name: localsink

View File

@ -8,6 +8,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/util/bandwidthlimit"
) )
func JobsFromConfig(c *config.Config) ([]Job, error) { func JobsFromConfig(c *config.Config) ([]Job, error) {
@ -107,3 +108,13 @@ func validateReceivingSidesDoNotOverlap(receivingRootFSs []string) error {
} }
return nil return nil
} }
func buildBandwidthLimitConfig(in *config.BandwidthLimit) (c bandwidthlimit.Config, _ error) {
if in.Max.ToBytes() > 0 && int64(in.Max.ToBytes()) == 0 {
return c, fmt.Errorf("bandwidth limit `max` is too small, must at least specify one byte")
}
return bandwidthlimit.Config{
Max: int64(in.Max.ToBytes()),
BucketCapacity: int64(in.BucketCapacity.ToBytes()),
}, nil
}

View File

@ -22,7 +22,12 @@ func buildSenderConfig(in SendingJobConfig, jobID endpoint.JobID) (*endpoint.Sen
return nil, errors.Wrap(err, "cannot build filesystem filter") return nil, errors.Wrap(err, "cannot build filesystem filter")
} }
sendOpts := in.GetSendOptions() sendOpts := in.GetSendOptions()
return &endpoint.SenderConfig{ bwlim, err := buildBandwidthLimitConfig(sendOpts.BandwidthLimit)
if err != nil {
return nil, errors.Wrap(err, "cannot build bandwith limit config")
}
sc := &endpoint.SenderConfig{
FSF: fsf, FSF: fsf,
JobID: jobID, JobID: jobID,
@ -34,7 +39,15 @@ func buildSenderConfig(in SendingJobConfig, jobID endpoint.JobID) (*endpoint.Sen
SendCompressed: sendOpts.Compressed, SendCompressed: sendOpts.Compressed,
SendEmbeddedData: sendOpts.EmbeddedData, SendEmbeddedData: sendOpts.EmbeddedData,
SendSaved: sendOpts.Saved, SendSaved: sendOpts.Saved,
}, nil
BandwidthLimit: bwlim,
}
if err := sc.Validate(); err != nil {
return nil, errors.Wrap(err, "cannot build sender config")
}
return sc, nil
} }
type ReceivingJobConfig interface { type ReceivingJobConfig interface {
@ -53,6 +66,12 @@ func buildReceiverConfig(in ReceivingJobConfig, jobID endpoint.JobID) (rc endpoi
} }
recvOpts := in.GetRecvOptions() recvOpts := in.GetRecvOptions()
bwlim, err := buildBandwidthLimitConfig(recvOpts.BandwidthLimit)
if err != nil {
return rc, errors.Wrap(err, "cannot build bandwith limit config")
}
rc = endpoint.ReceiverConfig{ rc = endpoint.ReceiverConfig{
JobID: jobID, JobID: jobID,
RootWithoutClientComponent: rootFs, RootWithoutClientComponent: rootFs,
@ -60,6 +79,8 @@ func buildReceiverConfig(in ReceivingJobConfig, jobID endpoint.JobID) (rc endpoi
InheritProperties: recvOpts.Properties.Inherit, InheritProperties: recvOpts.Properties.Inherit,
OverrideProperties: recvOpts.Properties.Override, OverrideProperties: recvOpts.Properties.Override,
BandwidthLimit: bwlim,
} }
if err := rc.Validate(); err != nil { if err := rc.Validate(); err != nil {
return rc, errors.Wrap(err, "cannot build receiver config") return rc, errors.Wrap(err, "cannot build receiver config")

View File

@ -119,6 +119,14 @@ func TestSampleConfigsAreBuiltWithoutErrors(t *testing.T) {
t.Errorf("glob failed: %+v", err) t.Errorf("glob failed: %+v", err)
} }
type additionalCheck struct {
state int
test func(t *testing.T, jobs []Job)
}
additionalChecks := map[string]*additionalCheck{
"bandwidth_limit.yml": {test: testSampleConfig_BandwidthLimit},
}
for _, p := range paths { for _, p := range paths {
if path.Ext(p) != ".yml" { if path.Ext(p) != ".yml" {
@ -126,10 +134,20 @@ func TestSampleConfigsAreBuiltWithoutErrors(t *testing.T) {
continue continue
} }
filename := path.Base(p)
t.Logf("checking for presence additonal checks for file %q", filename)
additionalCheck := additionalChecks[filename]
if additionalCheck == nil {
t.Logf("no additional checks")
} else {
t.Logf("additional check present")
additionalCheck.state = 1
}
t.Run(p, func(t *testing.T) { t.Run(p, func(t *testing.T) {
c, err := config.ParseConfig(p) c, err := config.ParseConfig(p)
if err != nil { if err != nil {
t.Errorf("error parsing %s:\n%+v", p, err) t.Fatalf("error parsing %s:\n%+v", p, err)
} }
t.Logf("file: %s", p) t.Logf("file: %s", p)
@ -138,11 +156,57 @@ func TestSampleConfigsAreBuiltWithoutErrors(t *testing.T) {
tls.FakeCertificateLoading(t) tls.FakeCertificateLoading(t)
jobs, err := JobsFromConfig(c) jobs, err := JobsFromConfig(c)
t.Logf("jobs: %#v", jobs) t.Logf("jobs: %#v", jobs)
assert.NoError(t, err) require.NoError(t, err)
if additionalCheck != nil {
additionalCheck.test(t, jobs)
additionalCheck.state = 2
}
}) })
} }
for basename, c := range additionalChecks {
if c.state == 0 {
panic("univisited additional check " + basename)
}
}
}
func testSampleConfig_BandwidthLimit(t *testing.T, jobs []Job) {
require.Len(t, jobs, 3)
{
limitedSink, ok := jobs[0].(*PassiveSide)
require.True(t, ok, "%T", jobs[0])
limitedSinkMode, ok := limitedSink.mode.(*modeSink)
require.True(t, ok, "%T", limitedSink)
assert.Equal(t, int64(12345), limitedSinkMode.receiverConfig.BandwidthLimit.Max)
assert.Equal(t, int64(1<<17), limitedSinkMode.receiverConfig.BandwidthLimit.BucketCapacity)
}
{
limitedPush, ok := jobs[1].(*ActiveSide)
require.True(t, ok, "%T", jobs[1])
limitedPushMode, ok := limitedPush.mode.(*modePush)
require.True(t, ok, "%T", limitedPush)
assert.Equal(t, int64(54321), limitedPushMode.senderConfig.BandwidthLimit.Max)
assert.Equal(t, int64(1024), limitedPushMode.senderConfig.BandwidthLimit.BucketCapacity)
}
{
unlimitedSink, ok := jobs[2].(*PassiveSide)
require.True(t, ok, "%T", jobs[2])
unlimitedSinkMode, ok := unlimitedSink.mode.(*modeSink)
require.True(t, ok, "%T", unlimitedSink)
max := unlimitedSinkMode.receiverConfig.BandwidthLimit.Max
assert.Less(t, max, int64(0), max, "unlimited mode <=> negative value for .Max, see bandwidthlimit.Config")
}
} }
func TestReplicationOptions(t *testing.T) { func TestReplicationOptions(t *testing.T) {

View File

@ -36,6 +36,9 @@ See the `upstream man page <https://openzfs.github.io/openzfs-docs/man/8/zfs-sen
* - ``encrypted`` * - ``encrypted``
- -
- Specific to zrepl, :ref:`see below <job-send-options-encrypted>`. - Specific to zrepl, :ref:`see below <job-send-options-encrypted>`.
* - ``bandwidth_limit``
-
- Specific to zrepl, :ref:`see below <job-send-recv-options-bandwidth-limit>`.
* - ``raw`` * - ``raw``
- ``-w`` - ``-w``
- Use ``encrypted`` to only allow encrypted sends. - Use ``encrypted`` to only allow encrypted sends.
@ -138,6 +141,7 @@ Recv Options
override: { override: {
"org.openzfs.systemd:ignore": "on" "org.openzfs.systemd:ignore": "on"
} }
bandwidth_limit: ... # see below
... ...
.. _job-recv-options--inherit-and-override: .. _job-recv-options--inherit-and-override:
@ -212,3 +216,25 @@ and property replication is enabled, the receiver must :ref:`inherit the followi
* ``keylocation`` * ``keylocation``
* ``keyformat`` * ``keyformat``
* ``encryption`` * ``encryption``
Common Options
~~~~~~~~~~~~~~
.. _job-send-recv-options-bandwidth-limit:
Bandwidth Limit (send & recv)
-----------------------------
::
bandwidth_limit:
max: 23.5 MiB # -1 is the default and disabled rate limiting
bucket_capacity: # token bucket capacity in bytes; defaults to 128KiB
Both ``send`` and ``recv`` can be limited to a maximum bandwidth through ``bandwidth_limit``.
For most users, it should be sufficient to just set ``bandwidth_limit.max``.
The ``bandwidth_limit.bucket_capacity`` refers to the `token bucket size <https://github.com/juju/ratelimit>`_.
The bandwidth limit only applies to the payload data, i.e., the ZFS send stream.
It does not account for transport protocol overheads.
The scope is the job level, i.e., all :ref:`concurrent <replication-option-concurrency>` sends or incoming receives of a job share the bandwidth limit.

View File

@ -32,6 +32,7 @@ We would like to thank the following people and organizations for supporting zre
<div class="fa fa-code" style="width: 1em;"></div> <div class="fa fa-code" style="width: 1em;"></div>
* |supporter-gold| Prominic.NET, Inc.
* |supporter-std| Torsten Blum * |supporter-std| Torsten Blum
* |supporter-gold| Cyberiada GmbH * |supporter-gold| Cyberiada GmbH
* |supporter-std| `Gordon Schulz <https://github.com/azmodude>`_ * |supporter-std| `Gordon Schulz <https://github.com/azmodude>`_

View File

@ -14,6 +14,7 @@ import (
"github.com/zrepl/zrepl/daemon/logging/trace" "github.com/zrepl/zrepl/daemon/logging/trace"
"github.com/zrepl/zrepl/replication/logic/pdu" "github.com/zrepl/zrepl/replication/logic/pdu"
"github.com/zrepl/zrepl/util/bandwidthlimit"
"github.com/zrepl/zrepl/util/chainedio" "github.com/zrepl/zrepl/util/chainedio"
"github.com/zrepl/zrepl/util/chainlock" "github.com/zrepl/zrepl/util/chainlock"
"github.com/zrepl/zrepl/util/envconst" "github.com/zrepl/zrepl/util/envconst"
@ -34,6 +35,8 @@ type SenderConfig struct {
SendCompressed bool SendCompressed bool
SendEmbeddedData bool SendEmbeddedData bool
SendSaved bool SendSaved bool
BandwidthLimit bandwidthlimit.Config
} }
func (c *SenderConfig) Validate() error { func (c *SenderConfig) Validate() error {
@ -44,6 +47,9 @@ func (c *SenderConfig) Validate() error {
if _, err := StepHoldTag(c.JobID); err != nil { if _, err := StepHoldTag(c.JobID); err != nil {
return fmt.Errorf("JobID cannot be used for hold tag: %s", err) return fmt.Errorf("JobID cannot be used for hold tag: %s", err)
} }
if err := bandwidthlimit.ValidateConfig(c.BandwidthLimit); err != nil {
return errors.Wrap(err, "`Ratelimit` field invalid")
}
return nil return nil
} }
@ -54,16 +60,21 @@ type Sender struct {
FSFilter zfs.DatasetFilter FSFilter zfs.DatasetFilter
jobId JobID jobId JobID
config SenderConfig config SenderConfig
bwLimit bandwidthlimit.Wrapper
} }
func NewSender(conf SenderConfig) *Sender { func NewSender(conf SenderConfig) *Sender {
if err := conf.Validate(); err != nil { if err := conf.Validate(); err != nil {
panic("invalid config" + err.Error()) panic("invalid config" + err.Error())
} }
ratelimiter := bandwidthlimit.WrapperFromConfig(conf.BandwidthLimit)
return &Sender{ return &Sender{
FSFilter: conf.FSF, FSFilter: conf.FSF,
jobId: conf.JobID, jobId: conf.JobID,
config: conf, config: conf,
bwLimit: ratelimiter,
} }
} }
@ -301,12 +312,16 @@ func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.Rea
abstractionsCacheSingleton.TryBatchDestroy(ctx, s.jobId, sendArgs.FS, destroyTypes, keep, check) abstractionsCacheSingleton.TryBatchDestroy(ctx, s.jobId, sendArgs.FS, destroyTypes, keep, check)
}() }()
sendStream, err := zfs.ZFSSend(ctx, sendArgs) var sendStream io.ReadCloser
sendStream, err = zfs.ZFSSend(ctx, sendArgs)
if err != nil { if err != nil {
// it's ok to not destroy the abstractions we just created here, a new send attempt will take care of it // it's ok to not destroy the abstractions we just created here, a new send attempt will take care of it
return nil, nil, errors.Wrap(err, "zfs send failed") return nil, nil, errors.Wrap(err, "zfs send failed")
} }
// apply rate limit
sendStream = s.bwLimit.WrapReadCloser(sendStream)
return res, sendStream, nil return res, sendStream, nil
} }
@ -439,6 +454,8 @@ type ReceiverConfig struct {
InheritProperties []zfsprop.Property InheritProperties []zfsprop.Property
OverrideProperties map[zfsprop.Property]string OverrideProperties map[zfsprop.Property]string
BandwidthLimit bandwidthlimit.Config
} }
func (c *ReceiverConfig) copyIn() { func (c *ReceiverConfig) copyIn() {
@ -475,6 +492,11 @@ func (c *ReceiverConfig) Validate() error {
if c.RootWithoutClientComponent.Length() <= 0 { if c.RootWithoutClientComponent.Length() <= 0 {
return errors.New("RootWithoutClientComponent must not be an empty dataset path") return errors.New("RootWithoutClientComponent must not be an empty dataset path")
} }
if err := bandwidthlimit.ValidateConfig(c.BandwidthLimit); err != nil {
return errors.Wrap(err, "`Ratelimit` field invalid")
}
return nil return nil
} }
@ -484,6 +506,8 @@ type Receiver struct {
conf ReceiverConfig // validated conf ReceiverConfig // validated
bwLimit bandwidthlimit.Wrapper
recvParentCreationMtx *chainlock.L recvParentCreationMtx *chainlock.L
} }
@ -495,6 +519,7 @@ func NewReceiver(config ReceiverConfig) *Receiver {
return &Receiver{ return &Receiver{
conf: config, conf: config,
recvParentCreationMtx: chainlock.New(), recvParentCreationMtx: chainlock.New(),
bwLimit: bandwidthlimit.WrapperFromConfig(config.BandwidthLimit),
} }
} }
@ -787,6 +812,9 @@ func (s *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, receive io.
return nil, errors.Wrap(err, "cannot determine whether we can use resumable send & recv") return nil, errors.Wrap(err, "cannot determine whether we can use resumable send & recv")
} }
// apply rate limit
receive = s.bwLimit.WrapReadCloser(receive)
var peek bytes.Buffer var peek bytes.Buffer
var MaxPeek = envconst.Int64("ZREPL_ENDPOINT_RECV_PEEK_SIZE", 1<<20) var MaxPeek = envconst.Int64("ZREPL_ENDPOINT_RECV_PEEK_SIZE", 1<<20)
log.WithField("max_peek_bytes", MaxPeek).Info("peeking incoming stream") log.WithField("max_peek_bytes", MaxPeek).Info("peeking incoming stream")

1
go.mod
View File

@ -15,6 +15,7 @@ require (
github.com/golang/protobuf v1.4.3 github.com/golang/protobuf v1.4.3
github.com/google/uuid v1.1.2 github.com/google/uuid v1.1.2
github.com/jinzhu/copier v0.0.0-20170922082739-db4671f3a9b8 github.com/jinzhu/copier v0.0.0-20170922082739-db4671f3a9b8
github.com/juju/ratelimit v1.0.1
github.com/kisielk/gotool v1.0.0 // indirect github.com/kisielk/gotool v1.0.0 // indirect
github.com/kr/pretty v0.1.0 github.com/kr/pretty v0.1.0
github.com/leodido/go-urn v1.2.1 // indirect github.com/leodido/go-urn v1.2.1 // indirect

2
go.sum
View File

@ -164,6 +164,8 @@ github.com/jinzhu/copier v0.0.0-20170922082739-db4671f3a9b8 h1:+dKzeuiDYbD/Cfi/s
github.com/jinzhu/copier v0.0.0-20170922082739-db4671f3a9b8/go.mod h1:yL958EeXv8Ylng6IfnvG4oflryUi3vgA3xPs9hmII1s= github.com/jinzhu/copier v0.0.0-20170922082739-db4671f3a9b8/go.mod h1:yL958EeXv8Ylng6IfnvG4oflryUi3vgA3xPs9hmII1s=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/juju/ratelimit v1.0.1 h1:+7AIFJVQ0EQgq/K9+0Krm7m530Du7tIz0METWzN0RgY=
github.com/juju/ratelimit v1.0.1/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 h1:uC1QfSlInpQF+M0ao65imhwqKnz3Q2z/d8PWZRMQvDM= github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 h1:uC1QfSlInpQF+M0ao65imhwqKnz3Q2z/d8PWZRMQvDM=

View File

@ -0,0 +1,72 @@
package bandwidthlimit
import (
"errors"
"io"
"github.com/juju/ratelimit"
)
type Wrapper interface {
WrapReadCloser(io.ReadCloser) io.ReadCloser
}
type Config struct {
// Units in this struct are in _bytes_.
Max int64 // < 0 means no limit, BucketCapacity is irrelevant then
BucketCapacity int64
}
func ValidateConfig(conf Config) error {
if conf.BucketCapacity == 0 {
return errors.New("BucketCapacity must not be zero")
}
return nil
}
func WrapperFromConfig(conf Config) Wrapper {
if err := ValidateConfig(conf); err != nil {
panic(err)
}
if conf.Max < 0 {
return noLimit{}
}
return &withLimit{
bucket: ratelimit.NewBucketWithRate(float64(conf.Max), conf.BucketCapacity),
}
}
type noLimit struct{}
func (_ noLimit) WrapReadCloser(rc io.ReadCloser) io.ReadCloser { return rc }
type withLimit struct {
bucket *ratelimit.Bucket
}
func (l *withLimit) WrapReadCloser(rc io.ReadCloser) io.ReadCloser {
return WrapReadCloser(rc, l.bucket)
}
type withLimitReadCloser struct {
orig io.Closer
limited io.Reader
}
func (r *withLimitReadCloser) Read(buf []byte) (int, error) {
return r.limited.Read(buf)
}
func (r *withLimitReadCloser) Close() error {
return r.orig.Close()
}
func WrapReadCloser(rc io.ReadCloser, bucket *ratelimit.Bucket) io.ReadCloser {
return &withLimitReadCloser{
limited: ratelimit.Reader(rc, bucket),
orig: rc,
}
}

View File

@ -0,0 +1,86 @@
package datasizeunit
import (
"errors"
"fmt"
"math"
"regexp"
"strconv"
"strings"
)
type Bits struct {
bits float64
}
func (b Bits) ToBits() float64 { return b.bits }
func (b Bits) ToBytes() float64 { return b.bits / 8 }
func FromBytesInt64(i int64) Bits { return Bits{float64(i) * 8} }
var datarateRegex = regexp.MustCompile(`^([-0-9\.]*)\s*(bit|(|K|Ki|M|Mi|G|Gi|T|Ti)([bB]))$`)
func (r *Bits) UnmarshalYAML(u func(interface{}, bool) error) (_ error) {
var s string
err := u(&s, false)
if err != nil {
return err
}
genericErr := func(err error) error {
var buf strings.Builder
fmt.Fprintf(&buf, "cannot parse %q using regex %s", s, datarateRegex)
if err != nil {
fmt.Fprintf(&buf, ": %s", err)
}
return errors.New(buf.String())
}
match := datarateRegex.FindStringSubmatch(s)
if match == nil {
return genericErr(nil)
}
bps, err := strconv.ParseFloat(match[1], 64)
if err != nil {
return genericErr(err)
}
if match[2] == "bit" {
if math.Round(bps) != bps {
return genericErr(fmt.Errorf("unit bit must be an integer value"))
}
r.bits = bps
return nil
}
factorMap := map[string]uint64{
"": 1,
"K": 1e3,
"M": 1e6,
"G": 1e9,
"T": 1e12,
"Ki": 1 << 10,
"Mi": 1 << 20,
"Gi": 1 << 30,
"Ti": 1 << 40,
}
factor, ok := factorMap[match[3]]
if !ok {
panic(match)
}
baseUnitFactorMap := map[string]uint64{
"b": 1,
"B": 8,
}
baseUnitFactor, ok := baseUnitFactorMap[match[4]]
if !ok {
panic(match)
}
r.bits = bps * float64(factor) * float64(baseUnitFactor)
return nil
}

View File

@ -0,0 +1,57 @@
package datasizeunit
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zrepl/yaml-config"
)
func TestBits(t *testing.T) {
tcs := []struct {
input string
expectRate float64
expectErr string
}{
{`23 bit`, 23, ""}, // bit special case works
{`23bit`, 23, ""}, // also without space
{`10MiB`, 10 * (1 << 20) * 8, ""}, // integer unit without space
{`10 MiB`, 8 * 10 * (1 << 20), ""}, // integer unit with space
{`10.5 Kib`, 10.5 * (1 << 10), ""}, // floating point with bit unit works with space
{`10.5Kib`, 10.5 * (1 << 10), ""}, // floating point with bit unit works without space
// unit checks
{`1 bit`, 1, ""},
{`1 B`, 1 * 8, ""},
{`1 Kb`, 1e3, ""},
{`1 Kib`, 1 << 10, ""},
{`1 Mb`, 1e6, ""},
{`1 Mib`, 1 << 20, ""},
{`1 Gb`, 1e9, ""},
{`1 Gib`, 1 << 30, ""},
{`1 Tb`, 1e12, ""},
{`1 Tib`, 1 << 40, ""},
}
for _, tc := range tcs {
t.Run(tc.input, func(t *testing.T) {
var bits Bits
err := yaml.Unmarshal([]byte(tc.input), &bits)
if tc.expectErr != "" {
assert.Error(t, err)
assert.Regexp(t, tc.expectErr, err.Error())
assert.Zero(t, bits.bits)
} else {
require.NoError(t, err)
assert.Equal(t, tc.expectRate, bits.bits)
}
})
}
}