diff --git a/cmd/zrok/davtest.go b/cmd/zrok/davtest.go new file mode 100644 index 00000000..78e919ba --- /dev/null +++ b/cmd/zrok/davtest.go @@ -0,0 +1,42 @@ +package main + +import ( + "context" + "github.com/openziti/zrok/util/sync/driveClient" + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "net/http" +) + +func init() { + rootCmd.AddCommand(newDavtestCommand().cmd) +} + +type davtestCommand struct { + cmd *cobra.Command +} + +func newDavtestCommand() *davtestCommand { + cmd := &cobra.Command{ + Use: "davtest", + Short: "WebDAV testing wrapper", + Args: cobra.ExactArgs(1), + } + command := &davtestCommand{cmd: cmd} + cmd.Run = command.run + return command +} + +func (cmd *davtestCommand) run(_ *cobra.Command, args []string) { + client, err := driveClient.NewClient(http.DefaultClient, args[0]) + if err != nil { + panic(err) + } + fis, err := client.Readdir(context.Background(), "/", true) + if err != nil { + panic(err) + } + for _, fi := range fis { + logrus.Infof("=> %s", fi.Path) + } +} diff --git a/util/sync/driveClient/client.go b/util/sync/driveClient/client.go index 9de24654..77cd2215 100644 --- a/util/sync/driveClient/client.go +++ b/util/sync/driveClient/client.go @@ -1,32 +1,268 @@ package driveClient -import "net/http" +import ( + "context" + "fmt" + "github.com/openziti/zrok/util/sync/driveClient/internal" + "io" + "net/http" + "time" +) +// HTTPClient performs HTTP requests. It's implemented by *http.Client. +type HTTPClient interface { + Do(req *http.Request) (*http.Response, error) +} + +type basicAuthHTTPClient struct { + c HTTPClient + username, password string +} + +func (c *basicAuthHTTPClient) Do(req *http.Request) (*http.Response, error) { + req.SetBasicAuth(c.username, c.password) + return c.c.Do(req) +} + +// HTTPClientWithBasicAuth returns an HTTP client that adds basic +// authentication to all outgoing requests. If c is nil, http.DefaultClient is +// used. +func HTTPClientWithBasicAuth(c HTTPClient, username, password string) HTTPClient { + if c == nil { + c = http.DefaultClient + } + return &basicAuthHTTPClient{c, username, password} +} + +// Client provides access to a remote WebDAV filesystem. type Client struct { - client *http.Client + ic *internal.Client } -func NewHttpClient() *Client { - return &Client{&http.Client{}} +func NewClient(c HTTPClient, endpoint string) (*Client, error) { + ic, err := internal.NewClient(c, endpoint) + if err != nil { + return nil, err + } + return &Client{ic}, nil } -func (c *Client) Connect() error { - return nil +func (c *Client) FindCurrentUserPrincipal(ctx context.Context) (string, error) { + propfind := internal.NewPropNamePropFind(internal.CurrentUserPrincipalName) + + // TODO: consider retrying on the root URI "/" if this fails, as suggested + // by the RFC? + resp, err := c.ic.PropFindFlat(ctx, "", propfind) + if err != nil { + return "", err + } + + var prop internal.CurrentUserPrincipal + if err := resp.DecodeProp(&prop); err != nil { + return "", err + } + if prop.Unauthenticated != nil { + return "", fmt.Errorf("webdav: unauthenticated") + } + + return prop.Href.Path, nil } -func (c *Client) options(uri string) (*http.Response, error) { - return c.request("OPTIONS", uri) -} +var fileInfoPropFind = internal.NewPropNamePropFind( + internal.ResourceTypeName, + internal.GetContentLengthName, + internal.GetLastModifiedName, + internal.GetContentTypeName, + internal.GetETagName, +) -func (c *Client) request(method, uri string) (resp *http.Response, err error) { - req, err := http.NewRequest(method, uri, nil) +func fileInfoFromResponse(resp *internal.Response) (*FileInfo, error) { + path, err := resp.Path() if err != nil { return nil, err } - if resp, err = c.client.Do(req); err != nil { - return resp, err + fi := &FileInfo{Path: path} + + var resType internal.ResourceType + if err := resp.DecodeProp(&resType); err != nil { + return nil, err } - return resp, err + if resType.Is(internal.CollectionName) { + fi.IsDir = true + } else { + var getLen internal.GetContentLength + if err := resp.DecodeProp(&getLen); err != nil { + return nil, err + } + + var getType internal.GetContentType + if err := resp.DecodeProp(&getType); err != nil && !internal.IsNotFound(err) { + return nil, err + } + + var getETag internal.GetETag + if err := resp.DecodeProp(&getETag); err != nil && !internal.IsNotFound(err) { + return nil, err + } + + fi.Size = getLen.Length + fi.MIMEType = getType.Type + fi.ETag = string(getETag.ETag) + } + + var getMod internal.GetLastModified + if err := resp.DecodeProp(&getMod); err != nil && !internal.IsNotFound(err) { + return nil, err + } + fi.ModTime = time.Time(getMod.LastModified) + + return fi, nil +} + +func (c *Client) Stat(ctx context.Context, name string) (*FileInfo, error) { + resp, err := c.ic.PropFindFlat(ctx, name, fileInfoPropFind) + if err != nil { + return nil, err + } + return fileInfoFromResponse(resp) +} + +func (c *Client) Open(ctx context.Context, name string) (io.ReadCloser, error) { + req, err := c.ic.NewRequest(http.MethodGet, name, nil) + if err != nil { + return nil, err + } + + resp, err := c.ic.Do(req.WithContext(ctx)) + if err != nil { + return nil, err + } + + return resp.Body, nil +} + +func (c *Client) Readdir(ctx context.Context, name string, recursive bool) ([]FileInfo, error) { + depth := internal.DepthOne + if recursive { + depth = internal.DepthInfinity + } + + ms, err := c.ic.PropFind(ctx, name, depth, fileInfoPropFind) + if err != nil { + return nil, err + } + + l := make([]FileInfo, 0, len(ms.Responses)) + for _, resp := range ms.Responses { + fi, err := fileInfoFromResponse(&resp) + if err != nil { + return l, err + } + l = append(l, *fi) + } + + return l, nil +} + +type fileWriter struct { + pw *io.PipeWriter + done <-chan error +} + +func (fw *fileWriter) Write(b []byte) (int, error) { + return fw.pw.Write(b) +} + +func (fw *fileWriter) Close() error { + if err := fw.pw.Close(); err != nil { + return err + } + return <-fw.done +} + +func (c *Client) Create(ctx context.Context, name string) (io.WriteCloser, error) { + pr, pw := io.Pipe() + + req, err := c.ic.NewRequest(http.MethodPut, name, pr) + if err != nil { + pw.Close() + return nil, err + } + + done := make(chan error, 1) + go func() { + resp, err := c.ic.Do(req.WithContext(ctx)) + if err != nil { + done <- err + return + } + resp.Body.Close() + done <- nil + }() + + return &fileWriter{pw, done}, nil +} + +func (c *Client) RemoveAll(ctx context.Context, name string) error { + req, err := c.ic.NewRequest(http.MethodDelete, name, nil) + if err != nil { + return err + } + + resp, err := c.ic.Do(req.WithContext(ctx)) + if err != nil { + return err + } + resp.Body.Close() + return nil +} + +func (c *Client) Mkdir(ctx context.Context, name string) error { + req, err := c.ic.NewRequest("MKCOL", name, nil) + if err != nil { + return err + } + + resp, err := c.ic.Do(req.WithContext(ctx)) + if err != nil { + return err + } + resp.Body.Close() + return nil +} + +func (c *Client) CopyAll(ctx context.Context, name, dest string, overwrite bool) error { + req, err := c.ic.NewRequest("COPY", name, nil) + if err != nil { + return err + } + + req.Header.Set("Destination", c.ic.ResolveHref(dest).String()) + req.Header.Set("Overwrite", internal.FormatOverwrite(overwrite)) + + resp, err := c.ic.Do(req.WithContext(ctx)) + if err != nil { + return err + } + resp.Body.Close() + return nil +} + +func (c *Client) MoveAll(ctx context.Context, name, dest string, overwrite bool) error { + req, err := c.ic.NewRequest("MOVE", name, nil) + if err != nil { + return err + } + + req.Header.Set("Destination", c.ic.ResolveHref(dest).String()) + req.Header.Set("Overwrite", internal.FormatOverwrite(overwrite)) + + resp, err := c.ic.Do(req.WithContext(ctx)) + if err != nil { + return err + } + resp.Body.Close() + return nil } diff --git a/util/sync/driveClient/internal/client.go b/util/sync/driveClient/internal/client.go new file mode 100644 index 00000000..718f436a --- /dev/null +++ b/util/sync/driveClient/internal/client.go @@ -0,0 +1,256 @@ +package internal + +import ( + "bytes" + "context" + "encoding/xml" + "fmt" + "io" + "mime" + "net" + "net/http" + "net/url" + "path" + "strings" + "unicode" +) + +// DiscoverContextURL performs a DNS-based CardDAV/CalDAV service discovery as +// described in RFC 6352 section 11. It returns the URL to the CardDAV server. +func DiscoverContextURL(ctx context.Context, service, domain string) (string, error) { + var resolver net.Resolver + + // Only lookup TLS records, plaintext connections are insecure + _, addrs, err := resolver.LookupSRV(ctx, service+"s", "tcp", domain) + if dnsErr, ok := err.(*net.DNSError); ok { + if dnsErr.IsTemporary { + return "", err + } + } else if err != nil { + return "", err + } + + if len(addrs) == 0 { + return "", fmt.Errorf("webdav: domain doesn't have an SRV record") + } + addr := addrs[0] + + target := strings.TrimSuffix(addr.Target, ".") + if target == "" { + return "", fmt.Errorf("webdav: empty target in SRV record") + } + + // TODO: perform a TXT lookup, check for a "path" key in the response + u := url.URL{Scheme: "https"} + if addr.Port == 443 { + u.Host = target + } else { + u.Host = fmt.Sprintf("%v:%v", target, addr.Port) + } + u.Path = "/.well-known/" + service + return u.String(), nil +} + +// HTTPClient performs HTTP requests. It's implemented by *http.Client. +type HTTPClient interface { + Do(req *http.Request) (*http.Response, error) +} + +type Client struct { + http HTTPClient + endpoint *url.URL +} + +func NewClient(c HTTPClient, endpoint string) (*Client, error) { + if c == nil { + c = http.DefaultClient + } + + u, err := url.Parse(endpoint) + if err != nil { + return nil, err + } + if u.Path == "" { + // This is important to avoid issues with path.Join + u.Path = "/" + } + return &Client{http: c, endpoint: u}, nil +} + +func (c *Client) ResolveHref(p string) *url.URL { + if !strings.HasPrefix(p, "/") { + p = path.Join(c.endpoint.Path, p) + } + return &url.URL{ + Scheme: c.endpoint.Scheme, + User: c.endpoint.User, + Host: c.endpoint.Host, + Path: p, + } +} + +func (c *Client) NewRequest(method string, path string, body io.Reader) (*http.Request, error) { + return http.NewRequest(method, c.ResolveHref(path).String(), body) +} + +func (c *Client) NewXMLRequest(method string, path string, v interface{}) (*http.Request, error) { + var buf bytes.Buffer + buf.WriteString(xml.Header) + if err := xml.NewEncoder(&buf).Encode(v); err != nil { + return nil, err + } + + req, err := c.NewRequest(method, path, &buf) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", "text/xml; charset=\"utf-8\"") + + return req, nil +} + +func (c *Client) Do(req *http.Request) (*http.Response, error) { + resp, err := c.http.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode/100 != 2 { + defer resp.Body.Close() + + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "text/plain" + } + + var wrappedErr error + t, _, _ := mime.ParseMediaType(contentType) + if t == "application/xml" || t == "text/xml" { + var davErr Error + if err := xml.NewDecoder(resp.Body).Decode(&davErr); err != nil { + wrappedErr = err + } else { + wrappedErr = &davErr + } + } else if strings.HasPrefix(t, "text/") { + lr := io.LimitedReader{R: resp.Body, N: 1024} + var buf bytes.Buffer + io.Copy(&buf, &lr) + resp.Body.Close() + if s := strings.TrimSpace(buf.String()); s != "" { + if lr.N == 0 { + s += " […]" + } + wrappedErr = fmt.Errorf("%v", s) + } + } + return nil, &HTTPError{Code: resp.StatusCode, Err: wrappedErr} + } + return resp, nil +} + +func (c *Client) DoMultiStatus(req *http.Request) (*MultiStatus, error) { + resp, err := c.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusMultiStatus { + return nil, fmt.Errorf("HTTP multi-status request failed: %v", resp.Status) + } + + // TODO: the response can be quite large, support streaming Response elements + var ms MultiStatus + if err := xml.NewDecoder(resp.Body).Decode(&ms); err != nil { + return nil, err + } + + return &ms, nil +} + +func (c *Client) PropFind(ctx context.Context, path string, depth Depth, propfind *PropFind) (*MultiStatus, error) { + req, err := c.NewXMLRequest("PROPFIND", path, propfind) + if err != nil { + return nil, err + } + + req.Header.Add("Depth", depth.String()) + + return c.DoMultiStatus(req.WithContext(ctx)) +} + +// PropfindFlat performs a PROPFIND request with a zero depth. +func (c *Client) PropFindFlat(ctx context.Context, path string, propfind *PropFind) (*Response, error) { + ms, err := c.PropFind(ctx, path, DepthZero, propfind) + if err != nil { + return nil, err + } + + // If the client followed a redirect, the Href might be different from the request path + if len(ms.Responses) != 1 { + return nil, fmt.Errorf("PROPFIND with Depth: 0 returned %d responses", len(ms.Responses)) + } + return &ms.Responses[0], nil +} + +func parseCommaSeparatedSet(values []string, upper bool) map[string]bool { + m := make(map[string]bool) + for _, v := range values { + fields := strings.FieldsFunc(v, func(r rune) bool { + return unicode.IsSpace(r) || r == ',' + }) + for _, f := range fields { + if upper { + f = strings.ToUpper(f) + } else { + f = strings.ToLower(f) + } + m[f] = true + } + } + return m +} + +func (c *Client) Options(ctx context.Context, path string) (classes map[string]bool, methods map[string]bool, err error) { + req, err := c.NewRequest(http.MethodOptions, path, nil) + if err != nil { + return nil, nil, err + } + + resp, err := c.Do(req.WithContext(ctx)) + if err != nil { + return nil, nil, err + } + resp.Body.Close() + + classes = parseCommaSeparatedSet(resp.Header["Dav"], false) + if !classes["1"] { + return nil, nil, fmt.Errorf("webdav: server doesn't support DAV class 1") + } + + methods = parseCommaSeparatedSet(resp.Header["Allow"], true) + return classes, methods, nil +} + +// SyncCollection perform a `sync-collection` REPORT operation on a resource +func (c *Client) SyncCollection(ctx context.Context, path, syncToken string, level Depth, limit *Limit, prop *Prop) (*MultiStatus, error) { + q := SyncCollectionQuery{ + SyncToken: syncToken, + SyncLevel: level.String(), + Limit: limit, + Prop: prop, + } + + req, err := c.NewXMLRequest("REPORT", path, &q) + if err != nil { + return nil, err + } + + ms, err := c.DoMultiStatus(req.WithContext(ctx)) + if err != nil { + return nil, err + } + + return ms, nil +} diff --git a/util/sync/driveClient/internal/elements.go b/util/sync/driveClient/internal/elements.go new file mode 100644 index 00000000..db7d9603 --- /dev/null +++ b/util/sync/driveClient/internal/elements.go @@ -0,0 +1,452 @@ +package internal + +import ( + "encoding/xml" + "errors" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + "time" +) + +const Namespace = "DAV:" + +var ( + ResourceTypeName = xml.Name{Namespace, "resourcetype"} + DisplayNameName = xml.Name{Namespace, "displayname"} + GetContentLengthName = xml.Name{Namespace, "getcontentlength"} + GetContentTypeName = xml.Name{Namespace, "getcontenttype"} + GetLastModifiedName = xml.Name{Namespace, "getlastmodified"} + GetETagName = xml.Name{Namespace, "getetag"} + + CurrentUserPrincipalName = xml.Name{Namespace, "current-user-principal"} +) + +type Status struct { + Code int + Text string +} + +func (s *Status) MarshalText() ([]byte, error) { + text := s.Text + if text == "" { + text = http.StatusText(s.Code) + } + return []byte(fmt.Sprintf("HTTP/1.1 %v %v", s.Code, text)), nil +} + +func (s *Status) UnmarshalText(b []byte) error { + if len(b) == 0 { + return nil + } + + parts := strings.SplitN(string(b), " ", 3) + if len(parts) != 3 { + return fmt.Errorf("webdav: invalid HTTP status %q: expected 3 fields", s) + } + code, err := strconv.Atoi(parts[1]) + if err != nil { + return fmt.Errorf("webdav: invalid HTTP status %q: failed to parse code: %v", s, err) + } + + s.Code = code + s.Text = parts[2] + return nil +} + +func (s *Status) Err() error { + if s == nil { + return nil + } + + // TODO: handle 2xx, 3xx + if s.Code != http.StatusOK { + return &HTTPError{Code: s.Code} + } + return nil +} + +type Href url.URL + +func (h *Href) String() string { + u := (*url.URL)(h) + return u.String() +} + +func (h *Href) MarshalText() ([]byte, error) { + return []byte(h.String()), nil +} + +func (h *Href) UnmarshalText(b []byte) error { + u, err := url.Parse(string(b)) + if err != nil { + return err + } + *h = Href(*u) + return nil +} + +// https://tools.ietf.org/html/rfc4918#section-14.16 +type MultiStatus struct { + XMLName xml.Name `xml:"DAV: multistatus"` + Responses []Response `xml:"response"` + ResponseDescription string `xml:"responsedescription,omitempty"` + SyncToken string `xml:"sync-token,omitempty"` +} + +func NewMultiStatus(resps ...Response) *MultiStatus { + return &MultiStatus{Responses: resps} +} + +// https://tools.ietf.org/html/rfc4918#section-14.24 +type Response struct { + XMLName xml.Name `xml:"DAV: response"` + Hrefs []Href `xml:"href"` + PropStats []PropStat `xml:"propstat,omitempty"` + ResponseDescription string `xml:"responsedescription,omitempty"` + Status *Status `xml:"status,omitempty"` + Error *Error `xml:"error,omitempty"` + Location *Location `xml:"location,omitempty"` +} + +func NewOKResponse(path string) *Response { + href := Href{Path: path} + return &Response{ + Hrefs: []Href{href}, + Status: &Status{Code: http.StatusOK}, + } +} + +func NewErrorResponse(path string, err error) *Response { + code := http.StatusInternalServerError + var httpErr *HTTPError + if errors.As(err, &httpErr) { + code = httpErr.Code + } + + var errElt *Error + errors.As(err, &errElt) + + href := Href{Path: path} + return &Response{ + Hrefs: []Href{href}, + Status: &Status{Code: code}, + ResponseDescription: err.Error(), + Error: errElt, + } +} + +func (resp *Response) Err() error { + if resp.Status == nil || resp.Status.Code/100 == 2 { + return nil + } + + var err error + if resp.Error != nil { + err = resp.Error + } + if resp.ResponseDescription != "" { + if err != nil { + err = fmt.Errorf("%v (%w)", resp.ResponseDescription, err) + } else { + err = fmt.Errorf("%v", resp.ResponseDescription) + } + } + + return &HTTPError{ + Code: resp.Status.Code, + Err: err, + } +} + +func (resp *Response) Path() (string, error) { + err := resp.Err() + var path string + if len(resp.Hrefs) == 1 { + path = resp.Hrefs[0].Path + } else if err == nil { + err = fmt.Errorf("webdav: malformed response: expected exactly one href element, got %v", len(resp.Hrefs)) + } + return path, err +} + +func (resp *Response) DecodeProp(values ...interface{}) error { + for _, v := range values { + // TODO wrap errors with more context (XML name) + name, err := valueXMLName(v) + if err != nil { + return err + } + if err := resp.Err(); err != nil { + return newPropError(name, err) + } + for _, propstat := range resp.PropStats { + raw := propstat.Prop.Get(name) + if raw == nil { + continue + } + if err := propstat.Status.Err(); err != nil { + return newPropError(name, err) + } + if err := raw.Decode(v); err != nil { + return newPropError(name, err) + } + return nil + } + return newPropError(name, &HTTPError{ + Code: http.StatusNotFound, + Err: fmt.Errorf("missing property"), + }) + } + + return nil +} + +func newPropError(name xml.Name, err error) error { + return fmt.Errorf("property <%v %v>: %w", name.Space, name.Local, err) +} + +func (resp *Response) EncodeProp(code int, v interface{}) error { + raw, err := EncodeRawXMLElement(v) + if err != nil { + return err + } + + for i := range resp.PropStats { + propstat := &resp.PropStats[i] + if propstat.Status.Code == code { + propstat.Prop.Raw = append(propstat.Prop.Raw, *raw) + return nil + } + } + + resp.PropStats = append(resp.PropStats, PropStat{ + Status: Status{Code: code}, + Prop: Prop{Raw: []RawXMLValue{*raw}}, + }) + return nil +} + +// https://tools.ietf.org/html/rfc4918#section-14.9 +type Location struct { + XMLName xml.Name `xml:"DAV: location"` + Href Href `xml:"href"` +} + +// https://tools.ietf.org/html/rfc4918#section-14.22 +type PropStat struct { + XMLName xml.Name `xml:"DAV: propstat"` + Prop Prop `xml:"prop"` + Status Status `xml:"status"` + ResponseDescription string `xml:"responsedescription,omitempty"` + Error *Error `xml:"error,omitempty"` +} + +// https://tools.ietf.org/html/rfc4918#section-14.18 +type Prop struct { + XMLName xml.Name `xml:"DAV: prop"` + Raw []RawXMLValue `xml:",any"` +} + +func EncodeProp(values ...interface{}) (*Prop, error) { + l := make([]RawXMLValue, len(values)) + for i, v := range values { + raw, err := EncodeRawXMLElement(v) + if err != nil { + return nil, err + } + l[i] = *raw + } + return &Prop{Raw: l}, nil +} + +func (p *Prop) Get(name xml.Name) *RawXMLValue { + for i := range p.Raw { + raw := &p.Raw[i] + if n, ok := raw.XMLName(); ok && name == n { + return raw + } + } + return nil +} + +func (p *Prop) Decode(v interface{}) error { + name, err := valueXMLName(v) + if err != nil { + return err + } + + raw := p.Get(name) + if raw == nil { + return HTTPErrorf(http.StatusNotFound, "missing property %s", name) + } + + return raw.Decode(v) +} + +// https://tools.ietf.org/html/rfc4918#section-14.20 +type PropFind struct { + XMLName xml.Name `xml:"DAV: propfind"` + Prop *Prop `xml:"prop,omitempty"` + AllProp *struct{} `xml:"allprop,omitempty"` + Include *Include `xml:"include,omitempty"` + PropName *struct{} `xml:"propname,omitempty"` +} + +func xmlNamesToRaw(names []xml.Name) []RawXMLValue { + l := make([]RawXMLValue, len(names)) + for i, name := range names { + l[i] = *NewRawXMLElement(name, nil, nil) + } + return l +} + +func NewPropNamePropFind(names ...xml.Name) *PropFind { + return &PropFind{Prop: &Prop{Raw: xmlNamesToRaw(names)}} +} + +// https://tools.ietf.org/html/rfc4918#section-14.8 +type Include struct { + XMLName xml.Name `xml:"DAV: include"` + Raw []RawXMLValue `xml:",any"` +} + +// https://tools.ietf.org/html/rfc4918#section-15.9 +type ResourceType struct { + XMLName xml.Name `xml:"DAV: resourcetype"` + Raw []RawXMLValue `xml:",any"` +} + +func NewResourceType(names ...xml.Name) *ResourceType { + return &ResourceType{Raw: xmlNamesToRaw(names)} +} + +func (t *ResourceType) Is(name xml.Name) bool { + for _, raw := range t.Raw { + if n, ok := raw.XMLName(); ok && name == n { + return true + } + } + return false +} + +var CollectionName = xml.Name{Namespace, "collection"} + +// https://tools.ietf.org/html/rfc4918#section-15.4 +type GetContentLength struct { + XMLName xml.Name `xml:"DAV: getcontentlength"` + Length int64 `xml:",chardata"` +} + +// https://tools.ietf.org/html/rfc4918#section-15.5 +type GetContentType struct { + XMLName xml.Name `xml:"DAV: getcontenttype"` + Type string `xml:",chardata"` +} + +type Time time.Time + +func (t *Time) UnmarshalText(b []byte) error { + tt, err := http.ParseTime(string(b)) + if err != nil { + return err + } + *t = Time(tt) + return nil +} + +func (t *Time) MarshalText() ([]byte, error) { + s := time.Time(*t).UTC().Format(http.TimeFormat) + return []byte(s), nil +} + +// https://tools.ietf.org/html/rfc4918#section-15.7 +type GetLastModified struct { + XMLName xml.Name `xml:"DAV: getlastmodified"` + LastModified Time `xml:",chardata"` +} + +// https://tools.ietf.org/html/rfc4918#section-15.6 +type GetETag struct { + XMLName xml.Name `xml:"DAV: getetag"` + ETag ETag `xml:",chardata"` +} + +type ETag string + +func (etag *ETag) UnmarshalText(b []byte) error { + s, err := strconv.Unquote(string(b)) + if err != nil { + return fmt.Errorf("webdav: failed to unquote ETag: %v", err) + } + *etag = ETag(s) + return nil +} + +func (etag ETag) MarshalText() ([]byte, error) { + return []byte(etag.String()), nil +} + +func (etag ETag) String() string { + return fmt.Sprintf("%q", string(etag)) +} + +// https://tools.ietf.org/html/rfc4918#section-14.5 +type Error struct { + XMLName xml.Name `xml:"DAV: error"` + Raw []RawXMLValue `xml:",any"` +} + +func (err *Error) Error() string { + b, _ := xml.Marshal(err) + return string(b) +} + +// https://tools.ietf.org/html/rfc4918#section-15.2 +type DisplayName struct { + XMLName xml.Name `xml:"DAV: displayname"` + Name string `xml:",chardata"` +} + +// https://tools.ietf.org/html/rfc5397#section-3 +type CurrentUserPrincipal struct { + XMLName xml.Name `xml:"DAV: current-user-principal"` + Href Href `xml:"href,omitempty"` + Unauthenticated *struct{} `xml:"unauthenticated,omitempty"` +} + +// https://tools.ietf.org/html/rfc4918#section-14.19 +type PropertyUpdate struct { + XMLName xml.Name `xml:"DAV: propertyupdate"` + Remove []Remove `xml:"remove"` + Set []Set `xml:"set"` +} + +// https://tools.ietf.org/html/rfc4918#section-14.23 +type Remove struct { + XMLName xml.Name `xml:"DAV: remove"` + Prop Prop `xml:"prop"` +} + +// https://tools.ietf.org/html/rfc4918#section-14.26 +type Set struct { + XMLName xml.Name `xml:"DAV: set"` + Prop Prop `xml:"prop"` +} + +// https://tools.ietf.org/html/rfc6578#section-6.1 +type SyncCollectionQuery struct { + XMLName xml.Name `xml:"DAV: sync-collection"` + SyncToken string `xml:"sync-token"` + Limit *Limit `xml:"limit,omitempty"` + SyncLevel string `xml:"sync-level"` + Prop *Prop `xml:"prop"` +} + +// https://tools.ietf.org/html/rfc5323#section-5.17 +type Limit struct { + XMLName xml.Name `xml:"DAV: limit"` + NResults uint `xml:"nresults"` +} diff --git a/util/sync/driveClient/internal/internal.go b/util/sync/driveClient/internal/internal.go new file mode 100644 index 00000000..1f3e0e52 --- /dev/null +++ b/util/sync/driveClient/internal/internal.go @@ -0,0 +1,108 @@ +package internal // Package internal provides low-level helpers for WebDAV clients and servers. +import ( + "errors" + "fmt" + "net/http" +) + +// Depth indicates whether a request applies to the resource's members. It's +// defined in RFC 4918 section 10.2. +type Depth int + +const ( + // DepthZero indicates that the request applies only to the resource. + DepthZero Depth = 0 + // DepthOne indicates that the request applies to the resource and its + // internal members only. + DepthOne Depth = 1 + // DepthInfinity indicates that the request applies to the resource and all + // of its members. + DepthInfinity Depth = -1 +) + +// ParseDepth parses a Depth header. +func ParseDepth(s string) (Depth, error) { + switch s { + case "0": + return DepthZero, nil + case "1": + return DepthOne, nil + case "infinity": + return DepthInfinity, nil + } + return 0, fmt.Errorf("webdav: invalid Depth value") +} + +// String formats the depth. +func (d Depth) String() string { + switch d { + case DepthZero: + return "0" + case DepthOne: + return "1" + case DepthInfinity: + return "infinity" + } + panic("webdav: invalid Depth value") +} + +// ParseOverwrite parses an Overwrite header. +func ParseOverwrite(s string) (bool, error) { + switch s { + case "T": + return true, nil + case "F": + return false, nil + } + return false, fmt.Errorf("webdav: invalid Overwrite value") +} + +// FormatOverwrite formats an Overwrite header. +func FormatOverwrite(overwrite bool) string { + if overwrite { + return "T" + } else { + return "F" + } +} + +type HTTPError struct { + Code int + Err error +} + +func HTTPErrorFromError(err error) *HTTPError { + if err == nil { + return nil + } + if httpErr, ok := err.(*HTTPError); ok { + return httpErr + } else { + return &HTTPError{http.StatusInternalServerError, err} + } +} + +func IsNotFound(err error) bool { + var httpErr *HTTPError + if errors.As(err, &httpErr) { + return httpErr.Code == http.StatusNotFound + } + return false +} + +func HTTPErrorf(code int, format string, a ...interface{}) *HTTPError { + return &HTTPError{code, fmt.Errorf(format, a...)} +} + +func (err *HTTPError) Error() string { + s := fmt.Sprintf("%v %v", err.Code, http.StatusText(err.Code)) + if err.Err != nil { + return fmt.Sprintf("%v: %v", s, err.Err) + } else { + return s + } +} + +func (err *HTTPError) Unwrap() error { + return err.Err +} diff --git a/util/sync/driveClient/internal/xml.go b/util/sync/driveClient/internal/xml.go new file mode 100644 index 00000000..2f4c61ca --- /dev/null +++ b/util/sync/driveClient/internal/xml.go @@ -0,0 +1,175 @@ +package internal + +import ( + "encoding/xml" + "fmt" + "io" + "reflect" + "strings" +) + +// RawXMLValue is a raw XML value. It implements xml.Unmarshaler and +// xml.Marshaler and can be used to delay XML decoding or precompute an XML +// encoding. +type RawXMLValue struct { + tok xml.Token // guaranteed not to be xml.EndElement + children []RawXMLValue + + // Unfortunately encoding/xml doesn't offer TokenWriter, so we need to + // cache outgoing data. + out interface{} +} + +// NewRawXMLElement creates a new RawXMLValue for an element. +func NewRawXMLElement(name xml.Name, attr []xml.Attr, children []RawXMLValue) *RawXMLValue { + return &RawXMLValue{tok: xml.StartElement{name, attr}, children: children} +} + +// EncodeRawXMLElement encodes a value into a new RawXMLValue. The XML value +// can only be used for marshalling. +func EncodeRawXMLElement(v interface{}) (*RawXMLValue, error) { + return &RawXMLValue{out: v}, nil +} + +// UnmarshalXML implements xml.Unmarshaler. +func (val *RawXMLValue) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + val.tok = start + val.children = nil + val.out = nil + + for { + tok, err := d.Token() + if err != nil { + return err + } + switch tok := tok.(type) { + case xml.StartElement: + child := RawXMLValue{} + if err := child.UnmarshalXML(d, tok); err != nil { + return err + } + val.children = append(val.children, child) + case xml.EndElement: + return nil + default: + val.children = append(val.children, RawXMLValue{tok: xml.CopyToken(tok)}) + } + } +} + +// MarshalXML implements xml.Marshaler. +func (val *RawXMLValue) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + if val.out != nil { + return e.Encode(val.out) + } + + switch tok := val.tok.(type) { + case xml.StartElement: + if err := e.EncodeToken(tok); err != nil { + return err + } + for _, child := range val.children { + // TODO: find a sensible value for the start argument? + if err := child.MarshalXML(e, xml.StartElement{}); err != nil { + return err + } + } + return e.EncodeToken(tok.End()) + case xml.EndElement: + panic("unexpected end element") + default: + return e.EncodeToken(tok) + } +} + +var _ xml.Marshaler = (*RawXMLValue)(nil) +var _ xml.Unmarshaler = (*RawXMLValue)(nil) + +func (val *RawXMLValue) Decode(v interface{}) error { + return xml.NewTokenDecoder(val.TokenReader()).Decode(&v) +} + +func (val *RawXMLValue) XMLName() (name xml.Name, ok bool) { + if start, ok := val.tok.(xml.StartElement); ok { + return start.Name, true + } + return xml.Name{}, false +} + +// TokenReader returns a stream of tokens for the XML value. +func (val *RawXMLValue) TokenReader() xml.TokenReader { + if val.out != nil { + panic("webdav: called RawXMLValue.TokenReader on a marshal-only XML value") + } + return &rawXMLValueReader{val: val} +} + +type rawXMLValueReader struct { + val *RawXMLValue + start, end bool + child int + childReader xml.TokenReader +} + +func (tr *rawXMLValueReader) Token() (xml.Token, error) { + if tr.end { + return nil, io.EOF + } + + start, ok := tr.val.tok.(xml.StartElement) + if !ok { + tr.end = true + return tr.val.tok, nil + } + + if !tr.start { + tr.start = true + return start, nil + } + + for tr.child < len(tr.val.children) { + if tr.childReader == nil { + tr.childReader = tr.val.children[tr.child].TokenReader() + } + + tok, err := tr.childReader.Token() + if err == io.EOF { + tr.childReader = nil + tr.child++ + } else { + return tok, err + } + } + + tr.end = true + return start.End(), nil +} + +var _ xml.TokenReader = (*rawXMLValueReader)(nil) + +func valueXMLName(v interface{}) (xml.Name, error) { + t := reflect.TypeOf(v) + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return xml.Name{}, fmt.Errorf("webdav: %T is not a struct", v) + } + nameField, ok := t.FieldByName("XMLName") + if !ok { + return xml.Name{}, fmt.Errorf("webdav: %T is missing an XMLName struct field", v) + } + if nameField.Type != reflect.TypeOf(xml.Name{}) { + return xml.Name{}, fmt.Errorf("webdav: %T.XMLName isn't an xml.Name", v) + } + tag := nameField.Tag.Get("xml") + if tag == "" { + return xml.Name{}, fmt.Errorf(`webdav: %T.XMLName is missing an "xml" tag`, v) + } + name := strings.Split(tag, ",")[0] + nameParts := strings.Split(name, " ") + if len(nameParts) != 2 { + return xml.Name{}, fmt.Errorf("webdav: expected a namespace and local name in %T.XMLName's xml tag", v) + } + return xml.Name{nameParts[0], nameParts[1]}, nil +} diff --git a/util/sync/driveClient/model.go b/util/sync/driveClient/model.go new file mode 100644 index 00000000..100994d3 --- /dev/null +++ b/util/sync/driveClient/model.go @@ -0,0 +1,119 @@ +package driveClient + +import ( + "errors" + "fmt" + "net/http" + "time" +) + +// Depth indicates whether a request applies to the resource's members. It's +// defined in RFC 4918 section 10.2. +type Depth int + +const ( + // DepthZero indicates that the request applies only to the resource. + DepthZero Depth = 0 + // DepthOne indicates that the request applies to the resource and its + // internal members only. + DepthOne Depth = 1 + // DepthInfinity indicates that the request applies to the resource and all + // of its members. + DepthInfinity Depth = -1 +) + +// ParseDepth parses a Depth header. +func ParseDepth(s string) (Depth, error) { + switch s { + case "0": + return DepthZero, nil + case "1": + return DepthOne, nil + case "infinity": + return DepthInfinity, nil + } + return 0, fmt.Errorf("webdav: invalid Depth value") +} + +// String formats the depth. +func (d Depth) String() string { + switch d { + case DepthZero: + return "0" + case DepthOne: + return "1" + case DepthInfinity: + return "infinity" + } + panic("webdav: invalid Depth value") +} + +// ParseOverwrite parses an Overwrite header. +func ParseOverwrite(s string) (bool, error) { + switch s { + case "T": + return true, nil + case "F": + return false, nil + } + return false, fmt.Errorf("webdav: invalid Overwrite value") +} + +// FormatOverwrite formats an Overwrite header. +func FormatOverwrite(overwrite bool) string { + if overwrite { + return "T" + } else { + return "F" + } +} + +type HTTPError struct { + Code int + Err error +} + +func HTTPErrorFromError(err error) *HTTPError { + if err == nil { + return nil + } + if httpErr, ok := err.(*HTTPError); ok { + return httpErr + } else { + return &HTTPError{http.StatusInternalServerError, err} + } +} + +func IsNotFound(err error) bool { + var httpErr *HTTPError + if errors.As(err, &httpErr) { + return httpErr.Code == http.StatusNotFound + } + return false +} + +func HTTPErrorf(code int, format string, a ...interface{}) *HTTPError { + return &HTTPError{code, fmt.Errorf(format, a...)} +} + +func (err *HTTPError) Error() string { + s := fmt.Sprintf("%v %v", err.Code, http.StatusText(err.Code)) + if err.Err != nil { + return fmt.Sprintf("%v: %v", s, err.Err) + } else { + return s + } +} + +func (err *HTTPError) Unwrap() error { + return err.Err +} + +type FileInfo struct { + Path string + Size int64 + ModTime time.Time + IsDir bool + MIMEType string + ETag string +}