server : add more parameters to server api (#1754)

* feat(server): add more parameters to server api

* fix(server): reset params to original parsed values for each request
This commit is contained in:
George Hindle 2024-01-12 11:42:52 +00:00 committed by GitHub
parent 6b01e3fedd
commit fbcb52d3cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -397,6 +397,13 @@ std::string output_str(struct whisper_context * ctx, const whisper_params & para
return result.str();
}
bool parse_str_to_bool(const std::string & s) {
if (s == "true" || s == "1" || s == "yes" || s == "y") {
return true;
}
return false;
}
void get_req_parameters(const Request & req, whisper_params & params)
{
if (req.has_file("offset_t"))
@ -415,6 +422,62 @@ void get_req_parameters(const Request & req, whisper_params & params)
{
params.max_context = std::stoi(req.get_file_value("max_context").content);
}
if (req.has_file("max_len"))
{
params.max_len = std::stoi(req.get_file_value("max_len").content);
}
if (req.has_file("best_of"))
{
params.best_of = std::stoi(req.get_file_value("best_of").content);
}
if (req.has_file("beam_size"))
{
params.beam_size = std::stoi(req.get_file_value("beam_size").content);
}
if (req.has_file("word_thold"))
{
params.word_thold = std::stof(req.get_file_value("word_thold").content);
}
if (req.has_file("entropy_thold"))
{
params.entropy_thold = std::stof(req.get_file_value("entropy_thold").content);
}
if (req.has_file("logprob_thold"))
{
params.logprob_thold = std::stof(req.get_file_value("logprob_thold").content);
}
if (req.has_file("debug_mode"))
{
params.debug_mode = parse_str_to_bool(req.get_file_value("debug_mode").content);
}
if (req.has_file("translate"))
{
params.translate = parse_str_to_bool(req.get_file_value("translate").content);
}
if (req.has_file("diarize"))
{
params.diarize = parse_str_to_bool(req.get_file_value("diarize").content);
}
if (req.has_file("tinydiarize"))
{
params.tinydiarize = parse_str_to_bool(req.get_file_value("tinydiarize").content);
}
if (req.has_file("split_on_word"))
{
params.split_on_word = parse_str_to_bool(req.get_file_value("split_on_word").content);
}
if (req.has_file("no_timestamps"))
{
params.no_timestamps = parse_str_to_bool(req.get_file_value("no_timestamps").content);
}
if (req.has_file("language"))
{
params.language = req.get_file_value("language").content;
}
if (req.has_file("detect_language"))
{
params.detect_language = parse_str_to_bool(req.get_file_value("detect_language").content);
}
if (req.has_file("prompt"))
{
params.prompt = req.get_file_value("prompt").content;
@ -482,6 +545,9 @@ int main(int argc, char ** argv) {
std::string const default_content = "<html>hello</html>";
// store default params so we can reset after each inference request
whisper_params default_params = params;
// this is only called if no index.html is found in the public --path
svr.Get(sparams.request_path + "/", [&default_content](const Request &, Response &res){
res.set_content(default_content, "text/html");
@ -724,6 +790,9 @@ int main(int argc, char ** argv) {
"application/json");
}
// reset params to thier defaults
params = default_params;
// return whisper model mutex lock
whisper_mutex.unlock();
});