diff --git a/examples/control-vector-generator/control-vector-generator.cpp b/examples/control-vector-generator/control-vector-generator.cpp index 3ca2c530520f6..d6c3fa3cd8082 100644 --- a/examples/control-vector-generator/control-vector-generator.cpp +++ b/examples/control-vector-generator/control-vector-generator.cpp @@ -258,6 +258,8 @@ struct ctrl_params { /* pair of prompts to be used for generating final vector */ std::vector positive_entries; std::vector negative_entries; + + bool single_prompt = false; }; struct tokenized_prompt { @@ -310,6 +312,8 @@ static void print_usage(const char * executable) { printf(" default: %s\n", defaults.negative_prompts_file.c_str()); printf(" -cf, --completions-file completions file\n"); printf(" default: %s\n", defaults.completions_file.c_str()); + printf(" --single-prompt use whole file as the prompt (allows for a multiline prompt)\n"); + printf(" default: use each line as a separate prompt\n"); printf(" -nc, --num-completions N number of lines of completions file to use\n"); printf(" default: %d\n", defaults.n_completions); printf(" --batch-pca N batch size used for PCA. Larger batch runs faster, but uses more memory\n"); @@ -377,6 +381,10 @@ static int ctrlvec_params_parse_ex(int argc, char ** argv, ctrl_params & params) throw std::invalid_argument("error: missing argument for " + arg); } } + if (arg == "--single-prompt") { + params.single_prompt = true; + skipme += 1; + } if (arg == "--num-completions" || arg == "-nc") { if (++arg_idx < argc && strncmp(argv[arg_idx], arg_prefix.c_str(), 2) != 0) { try { @@ -435,7 +443,7 @@ static int ctrlvec_params_parse(int argc, char ** argv, ctrl_params & params) { return skipme; } -static std::vector ctrlvec_load_prompt_file(std::string path, bool skip_empty_lines = false) { +static std::vector ctrlvec_load_prompt_file(std::string path, bool skip_empty_lines = false, bool single_prompt = false) { std::vector output; std::ifstream file(path); if (!file.is_open()) { @@ -449,6 +457,14 @@ static std::vector ctrlvec_load_prompt_file(std::string path, bool } } file.close(); + if (single_prompt) { + std::string single_prompt; + for (const auto & line : output) { + single_prompt += line + "\n"; + } + output.clear(); + output.push_back(single_prompt); + } return output; } @@ -510,8 +526,8 @@ static void export_gguf(const std::vector & v_ctrl, const */ static int prepare_entries(ctrl_params & cparams) { // load prompts - std::vector positive_prompts = ctrlvec_load_prompt_file(cparams.positive_prompts_file); - std::vector negative_prompts = ctrlvec_load_prompt_file(cparams.negative_prompts_file); + std::vector positive_prompts = ctrlvec_load_prompt_file(cparams.positive_prompts_file, false, cparams.single_prompt); + std::vector negative_prompts = ctrlvec_load_prompt_file(cparams.negative_prompts_file, false, cparams.single_prompt); if (positive_prompts.size() != negative_prompts.size()) { fprintf(stderr, "number of positive and negative prompts must be equal\n"); return 1;