Added output files and results in the README

main
Pavan Mandava 3 years ago
parent 7b5c9920fb
commit 1ac5d6ed15

@@ -231,24 +231,25 @@ Example: `sh test_prompting.sh -m 50-dpd/experiment-20221003T172424/epoch-09`
The generated belief states (outputs) are saved under the `OUTPUTS_DIR_PROMPT` folder. Some of the best outputs are uploaded to this repository and can be found in the [outputs](outputs) folder.
### Prompting Evaluation
The standard Joint Goal Accuracy (**JGA**) is used to evaluate the belief state predictions. To exclude the influence of wrongly extracted values, **JGA\*** is additionally computed only over the values that are extracted correctly at each turn.
The [evaluate.py](prompt-learning/evaluate.py) script can be used to verify the JGA scores reported below:
```shell
cd prompt-learning
python evaluate.py -o path/to/outputs/file
```
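For intuition, the sketch below illustrates the difference between **JGA** and **JGA\*** under one plausible reading of the definitions above: a belief state is a slot-to-value dict per turn, and each turn carries the set of slots whose values were extracted correctly. Both data layouts are assumptions made purely for illustration; the repository's actual computation lives in `PromptDSTEvaluator` in the `metrics` module used by [evaluate.py](prompt-learning/evaluate.py).
```python
# Illustrative only -- the data layout (slot -> value dicts, per-turn sets of
# correctly extracted slots) is an assumption, not the repository's format.

def joint_goal_accuracy(true_states, gen_states):
    # JGA: a turn counts only if the full belief state matches exactly
    correct = sum(1 for t, g in zip(true_states, gen_states) if t == g)
    return 100.0 * correct / len(true_states)

def jga_correct_values_only(true_states, gen_states, correctly_extracted):
    # JGA*: compare only the slots whose values were extracted correctly
    correct = 0
    for true, gen, ok_slots in zip(true_states, gen_states, correctly_extracted):
        if {s: true[s] for s in ok_slots} == {s: gen.get(s) for s in ok_slots}:
            correct += 1
    return 100.0 * correct / len(true_states)

# toy example: two turns, one wrong slot value in the first turn
true_states = [{'hotel-area': 'north', 'hotel-stars': '4'}, {'train-day': 'friday'}]
gen_states = [{'hotel-area': 'north', 'hotel-stars': '3'}, {'train-day': 'friday'}]
correctly_extracted = [{'hotel-area'}, {'train-day'}]
print(joint_goal_accuracy(true_states, gen_states))                           # 50.0
print(jga_correct_values_only(true_states, gen_states, correctly_extracted))  # 100.0
```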
### Results from prompt-based belief state generations
| data-split | JGA | JGA* |
|--|:--:|:--:|
| 5-dpd | 30.66 | 71.04 |
| 10-dpd | 42.65 | 86.43 |
| 50-dpd | 47.06 | 91.63 |
| 100-dpd | 47.74 | 92.31 |
| 125-dpd | 46.49 | 91.86 |
| 250-dpd | 47.06 | 92.08 |
> **Note:** The generated output files for all the results reported above are available in this repository. See the [outputs/prompt-learning](outputs/prompt-learning) directory for the output JSON file of each data-split.
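The per-split scores can also be reproduced in one pass by pointing the evaluator at each output file in turn. The snippet below is only a sketch: the `PromptDSTEvaluator` constructor and the two `compute_*` methods come from the repository's `metrics` module, but the relative path and the per-split file names are placeholders to be replaced with the actual JSON files under [outputs/prompt-learning](outputs/prompt-learning).
```python
# Sketch of a batch evaluation -- file names below are placeholders, not the
# actual names of the uploaded output files. Run from the prompt-learning folder.
import os

from metrics import PromptDSTEvaluator

OUTPUTS_DIR = '../outputs/prompt-learning'   # assumed relative location

for split in ['5-dpd', '10-dpd', '50-dpd', '100-dpd', '125-dpd', '250-dpd']:
    output_file = os.path.join(OUTPUTS_DIR, f'{split}.json')  # assumed naming
    if not os.path.isfile(output_file):
        print(f'No output file found for {split}, skipping.')
        continue
    print(f'--- {split} ---')
    evaluator = PromptDSTEvaluator(output_file)
    evaluator.compute_joint_goal_accuracy()      # JGA column
    evaluator.compute_jga_for_correct_values()   # JGA* column
```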
## Multi-prompt Learning Experiments

@@ -119,7 +119,7 @@ class BaselineDSTEvaluator:
        print('Evaluation :: Joint Goal Accuracy = ', (correctly_predicted / total_turns) * 100)


evaluator = BaselineDSTEvaluator('../outputs/baseline/50-dpd/output_test.json',
                                 '../data/baseline/test/test.soloist.json')
predicted_belief_states = evaluator.parse_prediction_belief_states()
true_belief_states = evaluator.parse_true_belief_states()

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -0,0 +1,34 @@
import argparse
import os

from metrics import PromptDSTEvaluator


def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument('-o', '--output_file', default=None, type=str, required=True,
                        help="The path of the outputs JSON file <JSON Path>.")
    # parse the arguments
    args = parser.parse_args()
    if args.output_file is None:
        print('No output file provided for evaluation!')
        return
    # Assertion check for file availability
    assert os.path.isfile(args.output_file)
    # create an evaluator instance for Prompt-based DST
    evaluator = PromptDSTEvaluator(args.output_file)
    # compute Joint Goal Accuracy
    evaluator.compute_joint_goal_accuracy()
    # compute JGA for the values that are correctly extracted in a turn
    evaluator.compute_jga_for_correct_values()


if __name__ == "__main__":
    main()

@@ -50,10 +50,10 @@ class PromptDSTEvaluator:
    def compute_jga_for_correct_values(self, no_print=False):
        if not no_print:
            print('Computing Joint Goal Accuracy metric for the values that are extracted correctly! (turn-level)')

        if len(self.true_states_list) != len(self.gen_states_list):
            raise ValueError('Unable to compute the metric. Length mismatch in the outputs!')

        # keep a count for computing JGA
        correctly_generated, total_turns = 0, 0
