You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
35 lines
912 B
35 lines
912 B
import argparse
|
|
import os
|
|
from metrics import PromptDSTEvaluator
|
|
|
|
|
|
def main(argv=None):
    """Evaluate a prompt-based DST outputs file from the command line.

    Parses ``-o/--output_file``, verifies that it names an existing file,
    then builds a ``PromptDSTEvaluator`` on it and runs two metrics:
    joint goal accuracy, and JGA for the values that are correctly
    extracted in a turn.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, so
            running the script behaves exactly as before. Passing a list
            lets the CLI be driven programmatically or from tests.

    Raises:
        SystemExit: raised by argparse (exit status 2) when the required
            ``-o/--output_file`` option is missing or when the given path
            is not an existing file.
    """
    parser = argparse.ArgumentParser()

    # Required parameters. `required=True` already guarantees a value, so no
    # `default` is needed (the old `default=None` could never take effect).
    parser.add_argument('-o', '--output_file', type=str, required=True,
                        help="The path of the outputs JSON file <JSON Path>.")

    # parse the arguments
    args = parser.parse_args(argv)

    # The only remaining check is that the path names a real file. Use
    # parser.error instead of `assert`: asserts are stripped under
    # `python -O` and would otherwise surface as a bare AssertionError
    # rather than a usage message with a non-zero exit status.
    if not os.path.isfile(args.output_file):
        parser.error("output file not found: {}".format(args.output_file))

    # create an evaluator instance for Prompt-based DST
    evaluator = PromptDSTEvaluator(args.output_file)

    # NOTE(review): the return values of the two metric calls below are
    # discarded here; presumably the evaluator reports its results itself
    # (e.g. prints/logs them) — confirm in `metrics.PromptDSTEvaluator`.

    # compute Joint Goal Accuracy
    evaluator.compute_joint_goal_accuracy()

    # compute JGA for the values that are correctly extracted in a turn
    evaluator.compute_jga_for_correct_values()
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the evaluation CLI only when this file is executed as a script;
    # importing the module (e.g. to reuse `main`) has no side effects.
    main()
|