feat: In cli.py and evaluator.py, allow output UTF-8 (for non-English content) and custom indentation #25

Open · wants to merge 1 commit into base: main
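
The motivation is easiest to see with Python's json module itself: by default, json.dumps escapes every non-ASCII character as a \uXXXX sequence, which makes saved results unreadable for non-English content. A small standard-library illustration (the example string is arbitrary):

import json

record = {"answer": "日本の首都は東京です"}  # arbitrary non-English content

# Default behaviour: non-ASCII characters are escaped.
print(json.dumps(record, indent=2))
# {
#   "answer": "\u65e5\u672c\u306e\u9996\u90fd\u306f\u6771\u4eac\u3067\u3059"
# }

# With ensure_ascii=False the text is written as readable UTF-8.
print(json.dumps(record, indent=2, ensure_ascii=False))
# {
#   "answer": "日本の首都は東京です"
# }
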
16 changes: 12 additions & 4 deletions ragchecker/cli.py
@@ -17,6 +17,15 @@ def get_args():
"--output_path", type=str, required=True,
help="Output path to the result json file."
)
parser.add_argument(
"--ensure_ascii", type=bool, default=True,
help="Whether to ensure ascii characters in output."
" (Set this to False if you are processing non-English content)"
)
parser.add_argument(
"--indent", type=int, default=2,
help="Set JSON output indent."
)
parser.add_argument(
'--extractor_name', type=str, default="bedrock/meta.llama3-70b-instruct-v1:0",
help="Model used for extracting claims. Default: bedrock/meta.llama3-70b-instruct-v1:0"
@@ -45,7 +54,6 @@ def get_args():
"--batch_size_checker", type=int, default=32,
help="Batch size for checker."
)

# checking options
parser.add_argument(
'--metrics', type=str, nargs='+', default=[all_metrics],
@@ -83,10 +91,10 @@ def main():
)
with open(args.input_path, "r") as f:
rag_results = RAGResults.from_json(f.read())
evaluator.evaluate(rag_results, metrics=args.metrics, save_path=args.output_path)
print(json.dumps(rag_results.metrics, indent=2))
evaluator.evaluate(rag_results, metrics=args.metrics, save_path=args.output_path, ensure_ascii=args.ensure_ascii)
print(json.dumps(rag_results.metrics, indent=args.indent, ensure_ascii=args.ensure_ascii))
with open(args.output_path, "w") as f:
f.write(rag_results.to_json(indent=2))
f.write(rag_results.to_json(indent=args.indent, ensure_ascii=args.ensure_ascii))


if __name__ == "__main__":
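
One caveat with the new flag as written: argparse's type=bool simply calls bool() on the raw command-line string, so --ensure_ascii False still parses as True because any non-empty string is truthy. A common workaround is a small converter; this is only a sketch of that pattern, not part of the diff above:

import argparse

def str2bool(value: str) -> bool:
    """Parse common truthy/falsy spellings so '--ensure_ascii False' behaves as expected."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--ensure_ascii", type=str2bool, default=True,
    help="Set to False when processing non-English content."
)

With such a converter, --ensure_ascii false and --ensure_ascii 0 both parse to False; alternatively, a store_false-style flag (e.g. a hypothetical --no_ensure_ascii) avoids string parsing entirely.
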
34 changes: 17 additions & 17 deletions ragchecker/evaluator.py
@@ -63,15 +63,15 @@ def __init__(
self.joint_check = joint_check
self.joint_check_num = joint_check_num
self.kwargs = kwargs

self.sagemaker_client = sagemaker_client
self.sagemaker_params = sagemaker_params
self.sagemaker_get_response_func = sagemaker_get_response_func

self.custom_llm_api_func = custom_llm_api_func

self.extractor = LLMExtractor(
model=extractor_name,
model=extractor_name,
batch_size=batch_size_extractor,
api_base=extractor_api_base
)
@@ -81,11 +81,11 @@ def __init__(
self.checker = AlignScoreChecker(batch_size=batch_size_checker)
else:
self.checker = LLMChecker(
model=checker_name,
model=checker_name,
batch_size=batch_size_checker,
api_base=checker_api_base
)

def extract_claims(self, results: List[RAGResult], extract_type="gt_answer"):
"""
Extract claims from the response and ground truth answer.
@@ -99,7 +99,7 @@ def extract_claims(self, results: List[RAGResult], extract_type="gt_answer"):
"""
assert extract_type in ["gt_answer", "response"], \
"extract_type should be either 'gt_answer' or 'response'."

if extract_type == "gt_answer":
results = [ret for ret in results if ret.gt_answer_claims is None]
texts = [result.gt_answer for result in results]
@@ -109,7 +109,7 @@ def extract_claims(self, results: List[RAGResult], extract_type="gt_answer"):
if not results:
return
questions = [result.query for result in results]

logger.info(f"Extracting claims for {extract_type} of {len(results)} RAG results.")
extraction_results = self.extractor.extract(
batch_responses=texts,
@@ -194,8 +194,8 @@ def check_claims(self, results: RAGResults, check_type="answer2response"):
result.retrieved2answer = checking_results[i]
else:
result.retrieved2response = checking_results[i]
def evaluate(self, results: RAGResults, metrics=all_metrics, save_path=None):

def evaluate(self, results: RAGResults, metrics=all_metrics, save_path=None, indent=2, ensure_ascii=True):
"""
Evaluate the RAG results.

@@ -207,7 +207,7 @@ def evaluate(self, results: RAGResults, metrics=all_metrics, save_path=None):
List of metrics to compute. Default: 'all'.
save_path : str, optional
Path to save the results. Default: None. Will perform progress checkpointing if provided.
"""
"""
# identify the metrics and required intermediate results
if isinstance(metrics, str):
metrics = [metrics]
@@ -222,19 +222,19 @@ def evaluate(self, results: RAGResults, metrics=all_metrics, save_path=None):
ret_metrics.add(metric)
for metric in ret_metrics:
requirements.update(METRIC_REQUIREMENTS[metric])

# compute the required intermediate results
for requirement in requirements:
self.check_claims(results, check_type=requirement)
if save_path is not None:
with open(save_path, "w") as f:
f.write(results.to_json(indent=2))
f.write(results.to_json(indent=indent, ensure_ascii=ensure_ascii))

# compute the metrics
for metric in ret_metrics:
for result in results.results:
METRIC_FUNC_MAP[metric](result)

# aggregate the metrics
for group, group_metrics in METRIC_GROUP_MAP.items():
if group == all_metrics:
@@ -244,10 +244,10 @@ def evaluate(self, results: RAGResults, metrics=all_metrics, save_path=None):
results.metrics[group][metric] = round(np.mean(
[result.metrics[metric] for result in results.results]
) * 100, 1)
# save the results

# save the results
if save_path is not None:
with open(save_path, "w") as f:
f.write(results.to_json(indent=2))
f.write(results.to_json(indent=indent, ensure_ascii=ensure_ascii))

return results.metrics
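
Taken together, the new parameters also let library users control the saved JSON directly, since evaluate() now forwards indent and ensure_ascii to to_json() when checkpointing and saving. A minimal usage sketch follows; it assumes the evaluator class is exposed as RAGChecker and can be constructed with its defaults (neither is shown in this diff), and the file paths are placeholders:

import json
from ragchecker import RAGChecker, RAGResults  # assumed public imports

evaluator = RAGChecker()  # extractor/checker options left at their defaults

with open("input.json") as f:  # placeholder input path
    rag_results = RAGResults.from_json(f.read())

# indent and ensure_ascii are passed through to the checkpoint/result files,
# so non-English claims stay human-readable on disk.
metrics = evaluator.evaluate(
    rag_results,
    save_path="result.json",  # placeholder output path
    indent=4,
    ensure_ascii=False,
)
print(json.dumps(metrics, indent=4, ensure_ascii=False))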