From c774eec709d153b94be60ebec8c7cb97f3bd82cd Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Fri, 7 Mar 2025 02:03:51 -0600 Subject: [PATCH] go : improve model download (#2756) * Updated models download URL * Updated list of models available All of the high efficiency quantized models are rejected when trying to download. They exist on the server. Let's allow them. * added path prefix for whisper-cli in message to user. The message is misleading if this script is called from another script in a different folder. So the message has to be fixed. * undid download URL change I made earlier. Fixed filepath.Join(urlPath, model) bug. * Undid download URL change I made earlier. Seems that the old URL works but only when provided a model to download. Still doesn't explain why there's a different download URL that also works. Please elucidate in docs. * Fixed URLForModel Function's bug filepath.Join is designed for filesystem paths, and it uses backslashes (\) on Windows. URLs, however, require forward slashes (/), so the use of filepath.Join is inappropriate for constructing URLs. The fmt.Sprintf function ensures that forward slashes are used. * Fixed URL trailing / double slash bug Ensure no double slash by trimming trailing '/' from srcUrl if present * Fixed bad download URL, missing ggml prefix Not sure if that was a bug I introduced but it was trying to download without the prefix. * Added question before downloading all models. Added download size estimate HEAD Requests: Efficiently fetches file sizes without downloading the content. Interactive Workflow: Allows the user to make informed decisions about downloading all models. Safe Defaults: Aborts if the user does not explicitly confirm. * Fixed Unbuffered channel warning. warning in context.go : misuse of unbuffered os.Signal channel as argument to signal.Notify. The warning indicates that the unbuffered channel used in signal.Notify in context.go may be misused. 
In Go, signal.Notify never blocks when sending to the channel, so with an unbuffered channel a signal that arrives before the receiver is ready is silently dropped; a buffer of 1 prevents that. * Fixed download size calculation, download URL prefix bug, added link to models URL for user. The URL formatter was prepending the model name to the formatted model name in the URL * Added logs and exes to gitignore * Delete bindings/go/examples/go-model-download/go-model-download.exe * Delete whisper_build.log --- .gitignore | 2 + .../go/examples/go-model-download/context.go | 29 ++--- .../go/examples/go-model-download/main.go | 116 ++++++++++++++++-- models/download-ggml-model.cmd | 15 ++- 4 files changed, 136 insertions(+), 26 deletions(-) diff --git a/.gitignore b/.gitignore index c1e584db..91368ec5 100644 --- a/.gitignore +++ b/.gitignore @@ -58,3 +58,5 @@ cmake-build-debug/ .cxx/ .gradle/ local.properties +.log +.exe \ No newline at end of file diff --git a/bindings/go/examples/go-model-download/context.go b/bindings/go/examples/go-model-download/context.go index 639d8f5b..7d5f0ddb 100644 --- a/bindings/go/examples/go-model-download/context.go +++ b/bindings/go/examples/go-model-download/context.go @@ -9,22 +9,23 @@ import ( // ContextForSignal returns a context object which is cancelled when a signal // is received. It returns nil if no signal parameter is provided func ContextForSignal(signals ...os.Signal) context.Context { - if len(signals) == 0 { - return nil - } + if len(signals) == 0 { + return nil + } - ch := make(chan os.Signal) - ctx, cancel := context.WithCancel(context.Background()) + ch := make(chan os.Signal, 1) // Buffered channel with space for 1 signal + ctx, cancel := context.WithCancel(context.Background()) - // Send message on channel when signal received - signal.Notify(ch, signals...) + // Send message on channel when signal received + signal.Notify(ch, signals...) 
- // When any signal received, call cancel - go func() { - <-ch - cancel() - }() + // When any signal is received, call cancel + go func() { + <-ch + cancel() + }() - // Return success - return ctx + // Return success + return ctx } + diff --git a/bindings/go/examples/go-model-download/main.go b/bindings/go/examples/go-model-download/main.go index d0c1cc78..728c6df5 100644 --- a/bindings/go/examples/go-model-download/main.go +++ b/bindings/go/examples/go-model-download/main.go @@ -9,6 +9,7 @@ import ( "net/url" "os" "path/filepath" + "strings" "syscall" "time" ) @@ -17,14 +18,27 @@ import ( // CONSTANTS const ( - srcUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main" // The location of the models - srcExt = ".bin" // Filename extension - bufSize = 1024 * 64 // Size of the buffer used for downloading the model + srcUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/" // The location of the models + srcExt = ".bin" // Filename extension + bufSize = 1024 * 64 // Size of the buffer used for downloading the model ) var ( // The models which will be downloaded, if no model is specified as an argument - modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large-v2", "ggml-large-v3", "large-v3-turbo"} + modelNames = []string{ + "tiny", "tiny-q5_1", "tiny-q8_0", + "tiny.en", "tiny.en-q5_1", "tiny.en-q8_0", + "base", "base-q5_1", "base-q8_0", + "base.en", "base.en-q5_1", "base.en-q8_0", + "small", "small-q5_1", "small-q8_0", + "small.en", "small.en-q5_1", "small.en-q8_0", + "medium", "medium-q5_0", "medium-q8_0", + "medium.en", "medium.en-q5_0", "medium.en-q8_0", + "large-v1", + "large-v2", "large-v2-q5_0", "large-v2-q8_0", + "large-v3", "large-v3-q5_0", + "large-v3-turbo", "large-v3-turbo-q5_0", "large-v3-turbo-q8_0", + } ) var ( @@ -44,7 +58,25 @@ var ( func main() { flag.Usage = func() { name := 
filepath.Base(flag.CommandLine.Name()) - fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] \n\n", name) + fmt.Fprintf(flag.CommandLine.Output(), ` + Usage: %s [options] [...] + + Options: + -out string Specify the output folder where models will be saved. + Default: Current working directory. + -timeout duration Set the maximum duration for downloading a model. + Example: 10m, 1h (default: 30m0s). + -quiet Suppress all output except errors. + + Examples: + 1. Download a specific model: + %s -out ./models tiny-q8_0 + + 2. Download all models: + %s -out ./models + + `, name, name, name) + flag.PrintDefaults() } flag.Parse() @@ -114,23 +146,87 @@ func GetOut() (string, error) { // GetModels returns the list of models to download func GetModels() []string { if flag.NArg() == 0 { - return modelNames - } else { - return flag.Args() + fmt.Println("No model specified.") + fmt.Println("Preparing to download all models...") + + // Calculate total download size + fmt.Println("Calculating total download size...") + totalSize, err := CalculateTotalDownloadSize(modelNames) + if err != nil { + fmt.Println("Error calculating download sizes:", err) + os.Exit(1) + } + + fmt.Println("View available models: https://huggingface.co/ggerganov/whisper.cpp/tree/main") + fmt.Printf("Total download size: %.2f GB\n", float64(totalSize)/(1024*1024*1024)) + fmt.Println("Would you like to download all models? (y/N)") + + // Prompt for user input + var response string + fmt.Scanln(&response) + if response != "y" && response != "Y" { + fmt.Println("Aborting. 
Specify a model to download.") + os.Exit(0) + } + + return modelNames // Return all models if confirmed } + return flag.Args() // Return specific models if arguments are provided +} + +func CalculateTotalDownloadSize(models []string) (int64, error) { + var totalSize int64 + client := http.Client{} + + for _, model := range models { + modelURL, err := URLForModel(model) + if err != nil { + return 0, err + } + + // Issue a HEAD request to get the file size + req, err := http.NewRequest("HEAD", modelURL, nil) + if err != nil { + return 0, err + } + + resp, err := client.Do(req) + if err != nil { + return 0, err + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + fmt.Printf("Warning: Unable to fetch size for %s (HTTP %d)\n", model, resp.StatusCode) + continue + } + + size := resp.ContentLength + totalSize += size + } + return totalSize, nil } // URLForModel returns the URL for the given model on huggingface.co func URLForModel(model string) (string, error) { + // Ensure "ggml-" prefix is added only once + if !strings.HasPrefix(model, "ggml-") { + model = "ggml-" + model + } + + // Ensure ".bin" extension is added only once if filepath.Ext(model) != srcExt { model += srcExt } + + // Parse the base URL url, err := url.Parse(srcUrl) if err != nil { return "", err - } else { - url.Path = filepath.Join(url.Path, model) } + + // Ensure no trailing slash in the base URL + url.Path = fmt.Sprintf("%s/%s", strings.TrimSuffix(url.Path, "/"), model) return url.String(), nil } diff --git a/models/download-ggml-model.cmd b/models/download-ggml-model.cmd index f329011d..566aa1bf 100644 --- a/models/download-ggml-model.cmd +++ b/models/download-ggml-model.cmd @@ -8,7 +8,18 @@ popd set argc=0 for %%x in (%*) do set /A argc+=1 -set models=tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3 large-v3-turbo +set models=tiny tiny-q5_1 tiny-q8_0 ^ +tiny.en tiny.en-q5_1 tiny.en-q8_0 ^ +base base-q5_1 base-q8_0 ^ +base.en base.en-q5_1 
base.en-q8_0 ^ +small small-q5_1 small-q8_0 ^ +small.en small.en-q5_1 small.en-q8_0 ^ +medium medium-q5_0 medium-q8_0 ^ +medium.en medium.en-q5_0 medium.en-q8_0 ^ +large-v1 ^ +large-v2 large-v2-q5_0 large-v2-q8_0 ^ +large-v3 large-v3-q5_0 ^ +large-v3-turbo large-v3-turbo-q5_0 large-v3-turbo-q8_0 if %argc% neq 1 ( echo. @@ -50,7 +61,7 @@ if %ERRORLEVEL% neq 0 ( echo Done! Model %model% saved in %root_path%\models\ggml-%model%.bin echo You can now use it like this: -echo build\bin\Release\whisper-cli.exe -m %root_path%\models\ggml-%model%.bin -f %root_path%\samples\jfk.wav +echo %~dp0build\bin\Release\whisper-cli.exe -m %root_path%\models\ggml-%model%.bin -f %root_path%\samples\jfk.wav goto :eof