From 9edbd0a204cd552cb327f80fd82f0ecc3ad7d1cf Mon Sep 17 00:00:00 2001 From: Neil Chudleigh Date: Mon, 25 Sep 2023 08:45:15 -0700 Subject: [PATCH] extra: Add benchmark script implemented in Python (#1298) * Create bench.py * Various benchmark results * Update benchmark script with hardware name, and file checks * Remove old benchmark results * Add git shorthash * Round to 2 digits on calculated floats * Fix the header reference when sorting results * FIx order of models * Parse file name * Simplify filecheck * Improve print run print statement * Use simplified model name * Update benchmark_results.csv * Process single or lists of processors and threads * Ignore benchmark results, dont check in * Move bench.py to extra folder * Readme section on how to use * Move command to correct location * Use separate list for models that exist * Handle subprocess error in git short hash check * Fix filtered models list initialization --- .gitignore | 2 + README.md | 13 +++ extra/bench.py | 222 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 237 insertions(+) create mode 100644 extra/bench.py diff --git a/.gitignore b/.gitignore index b30a1d19..ab1fc2e3 100644 --- a/.gitignore +++ b/.gitignore @@ -46,3 +46,5 @@ models/*.mlpackage bindings/java/.gradle/ bindings/java/.idea/ .idea/ + +benchmark_results.csv diff --git a/README.md b/README.md index 894e0e03..5831797f 100644 --- a/README.md +++ b/README.md @@ -709,6 +709,19 @@ took to execute it. The results are summarized in the following Github issue: [Benchmark results](https://github.com/ggerganov/whisper.cpp/issues/89) +Additionally a script to run whisper.cpp with different models and audio files is provided [bench.py](bench.py). + +You can run it with the following command, by default it will run against any standard model in the models folder. + +```bash +python3 extra/bench.py -f samples/jfk.wav -t 2,4,8 -p 1,2 +``` + +It is written in python with the intention of being easy to modify and extend for your benchmarking use case. + +It outputs a csv file with the results of the benchmarking. + + ## ggml format The original models are converted to a custom binary format. This allows to pack everything needed into a single file: diff --git a/extra/bench.py b/extra/bench.py new file mode 100644 index 00000000..74956e72 --- /dev/null +++ b/extra/bench.py @@ -0,0 +1,222 @@ +import os +import subprocess +import re +import csv +import wave +import contextlib +import argparse + + +# Custom action to handle comma-separated list +class ListAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, [int(val) for val in values.split(",")]) + + +parser = argparse.ArgumentParser(description="Benchmark the speech recognition model") + +# Define the argument to accept a list +parser.add_argument( + "-t", + "--threads", + dest="threads", + action=ListAction, + default=[4], + help="List of thread counts to benchmark (comma-separated, default: 4)", +) + +parser.add_argument( + "-p", + "--processors", + dest="processors", + action=ListAction, + default=[1], + help="List of processor counts to benchmark (comma-separated, default: 1)", +) + + +parser.add_argument( + "-f", + "--filename", + type=str, + default="./samples/jfk.wav", + help="Relative path of the file to transcribe (default: ./samples/jfk.wav)", +) + +# Parse the command line arguments +args = parser.parse_args() + +sample_file = args.filename + +threads = args.threads +processors = args.processors + +# Define the models, threads, and processor counts to benchmark +models = [ + "ggml-tiny.en.bin", + "ggml-tiny.bin", + "ggml-base.en.bin", + "ggml-base.bin", + "ggml-small.en.bin", + "ggml-small.bin", + "ggml-medium.en.bin", + "ggml-medium.bin", + "ggml-large.bin", +] + + +metal_device = "" + +# Initialize a dictionary to hold the results +results = {} + +gitHashHeader = "Commit" +modelHeader = "Model" +hardwareHeader = "Hardware" +recordingLengthHeader = "Recording Length (seconds)" +threadHeader = "Thread" +processorCountHeader = "Processor Count" +loadTimeHeader = "Load Time (ms)" +sampleTimeHeader = "Sample Time (ms)" +encodeTimeHeader = "Encode Time (ms)" +decodeTimeHeader = "Decode Time (ms)" +sampleTimePerRunHeader = "Sample Time per Run (ms)" +encodeTimePerRunHeader = "Encode Time per Run (ms)" +decodeTimePerRunHeader = "Decode Time per Run (ms)" +totalTimeHeader = "Total Time (ms)" + + +def check_file_exists(file: str) -> bool: + return os.path.isfile(file) + + +def get_git_short_hash() -> str: + try: + return ( + subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]) + .decode() + .strip() + ) + except subprocess.CalledProcessError as e: + return "" + + +def wav_file_length(file: str = sample_file) -> float: + with contextlib.closing(wave.open(file, "r")) as f: + frames = f.getnframes() + rate = f.getframerate() + duration = frames / float(rate) + return duration + + +def extract_metrics(output: str, label: str) -> tuple[float, float]: + match = re.search(rf"{label} \s*=\s*(\d+\.\d+)\s*ms\s*/\s*(\d+)\s*runs", output) + time = float(match.group(1)) if match else None + runs = float(match.group(2)) if match else None + return time, runs + + +def extract_device(output: str) -> str: + match = re.search(r"picking default device: (.*)", output) + device = match.group(1) if match else "Not found" + return device + + +# Check if the sample file exists +if not check_file_exists(sample_file): + raise FileNotFoundError(f"Sample file {sample_file} not found") + +recording_length = wav_file_length() + + +# Check that all models exist +# Filter out models from list that are not downloaded +filtered_models = [] +for model in models: + if check_file_exists(f"models/{model}"): + filtered_models.append(model) + else: + print(f"Model {model} not found, removing from list") + +models = filtered_models + +# Loop over each combination of parameters +for model in filtered_models: + for thread in threads: + for processor_count in processors: + # Construct the command to run + cmd = f"./main -m models/{model} -t {thread} -p {processor_count} -f {sample_file}" + # Run the command and get the output + process = subprocess.Popen( + cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) + + output = "" + while process.poll() is None: + output += process.stdout.read().decode() + + # Parse the output + load_time_match = re.search(r"load time\s*=\s*(\d+\.\d+)\s*ms", output) + load_time = float(load_time_match.group(1)) if load_time_match else None + + metal_device = extract_device(output) + sample_time, sample_runs = extract_metrics(output, "sample time") + encode_time, encode_runs = extract_metrics(output, "encode time") + decode_time, decode_runs = extract_metrics(output, "decode time") + + total_time_match = re.search(r"total time\s*=\s*(\d+\.\d+)\s*ms", output) + total_time = float(total_time_match.group(1)) if total_time_match else None + + model_name = model.replace("ggml-", "").replace(".bin", "") + + print( + f"Ran model={model_name} threads={thread} processor_count={processor_count}, took {total_time}ms" + ) + # Store the times in the results dictionary + results[(model_name, thread, processor_count)] = { + loadTimeHeader: load_time, + sampleTimeHeader: sample_time, + encodeTimeHeader: encode_time, + decodeTimeHeader: decode_time, + sampleTimePerRunHeader: round(sample_time / sample_runs, 2), + encodeTimePerRunHeader: round(encode_time / encode_runs, 2), + decodeTimePerRunHeader: round(decode_time / decode_runs, 2), + totalTimeHeader: total_time, + } + +# Write the results to a CSV file +with open("benchmark_results.csv", "w", newline="") as csvfile: + fieldnames = [ + gitHashHeader, + modelHeader, + hardwareHeader, + recordingLengthHeader, + threadHeader, + processorCountHeader, + loadTimeHeader, + sampleTimeHeader, + encodeTimeHeader, + decodeTimeHeader, + sampleTimePerRunHeader, + encodeTimePerRunHeader, + decodeTimePerRunHeader, + totalTimeHeader, + ] + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + writer.writeheader() + + shortHash = get_git_short_hash() + # Sort the results by total time in ascending order + sorted_results = sorted(results.items(), key=lambda x: x[1].get(totalTimeHeader, 0)) + for params, times in sorted_results: + row = { + gitHashHeader: shortHash, + modelHeader: params[0], + hardwareHeader: metal_device, + recordingLengthHeader: recording_length, + threadHeader: params[1], + processorCountHeader: params[2], + } + row.update(times) + writer.writerow(row)