Added prompt-training files

main
Pavan Mandava 3 years ago
parent d9e421cb2b
commit 9be77c2de6

.gitignore vendored

@ -181,4 +181,6 @@ ipython_config.py
 # runs folder
 baseline/runs/
+.local
+utils/stanford-corenlp

@ -119,7 +119,7 @@ class BaselineDSTEvaluator:
     print('Evaluation :: Joint Goal Accuracy = ', (correctly_predicted / total_turns) * 100)
-evaluator = BaselineDSTEvaluator('../outputs/baseline/experiment-20220829/50-dpd/checkpoint-55000/output_test.json',
+evaluator = BaselineDSTEvaluator('../outputs/baseline/50-dpd/checkpoint-55000/output_test.json',
                                  '../data/baseline/test/test.soloist.json')
 predicted_belief_states = evaluator.parse_prediction_belief_states()
 true_belief_states = evaluator.parse_true_belief_states()
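
For context, the metric this script prints is joint goal accuracy (JGA): a turn counts as correct only when the predicted belief state matches the ground truth exactly. A minimal sketch of that comparison with illustrative data (not the repo's actual helper API):

def joint_goal_accuracy(predicted_states, true_states):
    # a turn is correct only if its full belief state matches exactly
    correct = sum(set(pred) == set(true)
                  for pred, true in zip(predicted_states, true_states))
    return (correct / len(true_states)) * 100

# example with the "slot = value" strings used throughout this commit
pred = [["destination = cambridge", "day = saturday"], ["day = monday"]]
true = [["day = saturday", "destination = cambridge"], ["day = tuesday"]]
print(joint_goal_accuracy(pred, true))  # 50.0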

@ -0,0 +1 @@
["SNG1031.json", "SNG1031.json", "SNG1031.json", "SNG1031.json", "SNG1031.json", "SSNG0241.json", "SSNG0241.json", "SSNG0241.json", "SSNG0241.json", "SSNG0241.json", "SSNG0241.json", "SNG1272.json", "SNG1272.json", "SNG1272.json", "SNG0310.json", "SNG0310.json", "SNG0310.json", "SNG0310.json", "SNG0310.json", "SNG1034.json", "SNG1034.json", "SNG1034.json", "SNG1034.json", "SNG1034.json", "SNG1034.json", "SSNG0289.json", "SSNG0289.json", "SSNG0289.json", "SSNG0289.json", "SSNG0289.json", "SSNG0289.json", "SSNG0289.json", "SSNG0289.json", "SNG1277.json", "SNG1277.json", "SNG1277.json", "SNG1277.json", "SNG1277.json", "SNG0427.json", "SNG0427.json", "SNG0427.json", "SNG0427.json", "SNG0427.json", "SNG0868.json", "SNG0868.json", "SNG0868.json", "SNG0868.json", "WOZ20294.json", "WOZ20294.json", "WOZ20294.json", "WOZ20294.json", "WOZ20294.json", "WOZ20294.json", "SSNG0332.json", "SSNG0332.json", "SSNG0332.json", "SSNG0332.json", "SSNG0332.json", "SSNG0332.json", "SNG1251.json", "SNG1251.json", "SNG1251.json", "SNG1342.json", "SNG1342.json", "SNG1342.json", "SNG1342.json", "SNG1342.json", "SNG1342.json", "SNG02069.json", "SNG02069.json", "SNG02069.json", "SNG02069.json", "SNG02069.json", "SSNG0229.json", "SSNG0229.json", "SSNG0229.json", "SSNG0229.json", "SSNG0229.json", "SSNG0229.json", "SSNG0229.json", "SNG01238.json", "SNG01238.json", "SNG01338.json", "SNG01338.json", "SNG01338.json", "SNG01338.json", "SNG1080.json", "SNG1080.json", "SNG1080.json", "SNG1080.json", "SNG1366.json", "SNG1366.json", "SNG1366.json", "SNG1366.json", "SNG0460.json", "SNG0460.json", "SNG0460.json", "SNG0460.json", "SNG0460.json", "SNG1149.json", "SNG1149.json", "SNG1149.json", "SNG02173.json", "SNG02173.json", "SNG02173.json", "SNG02173.json", "SNG02173.json", "SNG02173.json", "SNG02173.json", "SNG0097.json", "SNG0097.json", "SNG0097.json", "SNG1273.json", "SNG1273.json", "SNG1273.json", "SNG1073.json", "SNG1073.json", "WOZ20587.json", "WOZ20587.json", "WOZ20587.json", "SNG1218.json", "SNG1218.json", "SNG1218.json", "SNG02350.json", "SNG02350.json", "SNG02350.json", "SNG0431.json", "SNG0431.json", "SNG0431.json", "SNG0431.json", "SNG0431.json", "SNG1137.json", "SNG1137.json", "SNG1137.json", "SNG1137.json", "SNG1137.json", "SNG1132.json", "SNG1132.json", "SNG1132.json", "SNG1132.json", "WOZ20353.json", "WOZ20353.json", "WOZ20353.json", "WOZ20353.json", "SNG1104.json", "SNG1104.json", "SNG1104.json", "SNG1104.json", "SNG1104.json", "SNG1187.json", "SNG1187.json", "SNG1187.json", "SNG1187.json", "WOZ20674.json", "WOZ20674.json", "WOZ20674.json", "WOZ20674.json", "SNG02348.json", "SNG02348.json", "SNG02348.json", "SNG02348.json", "SNG02348.json", "SNG0014.json", "SNG0014.json", "SNG0014.json", "SNG0014.json", "SNG0014.json", "SNG0014.json", "WOZ20471.json", "WOZ20471.json", "WOZ20471.json", "WOZ20471.json", "SNG0436.json", "SNG0436.json", "SNG0436.json", "SNG0436.json", "SNG0436.json", "SNG0436.json", "SNG0436.json", "SNG0346.json", "SNG0346.json", "SNG0346.json", "SNG0346.json", "SNG0346.json", "SSNG0055.json", "SSNG0055.json", "SSNG0055.json", "SSNG0055.json", "SSNG0055.json", "SSNG0055.json", "SNG1145.json", "SNG1145.json", "SNG1145.json", "SNG1145.json", "SNG01424.json", "SNG01424.json", "SNG01424.json", "SNG0262.json", "SNG0262.json", "SNG0262.json", "SNG0262.json", "SNG0262.json", "WOZ20123.json", "WOZ20123.json", "WOZ20123.json", "SNG0303.json", "SNG0303.json", "SNG0303.json", "SNG0303.json", "SNG0303.json", "SNG0303.json", "SNG0705.json", "SNG0705.json", "SNG0705.json", "SNG0705.json", "SNG0705.json", 
"SNG1331.json", "SNG1331.json", "SNG1331.json", "SNG1331.json", "SNG1331.json", "SNG1331.json", "SSNG0202.json", "SSNG0202.json", "SSNG0202.json", "SSNG0202.json", "SSNG0202.json", "SNG0334.json", "SNG0334.json", "SNG0334.json", "SNG0334.json", "SNG0334.json", "SNG0334.json", "SNG0334.json"]

File diff suppressed because it is too large

@ -0,0 +1 @@
["SNG0359.json", "SNG0359.json", "SNG0359.json", "SNG0359.json", "SNG1100.json", "SNG1100.json", "SNG1100.json", "SNG1100.json", "SNG1297.json", "SNG1297.json", "SNG1297.json", "SNG1107.json", "SNG1107.json", "SNG1107.json", "SNG1107.json", "WOZ20560.json", "WOZ20560.json", "SSNG0342.json", "SSNG0342.json", "SSNG0342.json", "WOZ20249.json", "WOZ20249.json", "WOZ20249.json", "WOZ20297.json", "WOZ20297.json", "WOZ20297.json", "SNG1163.json", "SNG1163.json", "SNG1163.json", "SNG1163.json", "SNG0410.json", "SNG0410.json", "SNG0410.json", "SNG0410.json", "SNG0877.json", "SNG0877.json", "SNG0877.json", "SNG0877.json", "SNG0877.json", "WOZ20570.json", "WOZ20570.json", "WOZ20570.json", "WOZ20570.json", "WOZ20570.json", "WOZ20570.json", "SNG0886.json", "SNG0886.json", "SNG0886.json", "SNG0886.json", "SNG01181.json", "SNG01181.json", "SNG01181.json", "SNG0352.json", "SNG0352.json", "SNG0352.json", "SNG0352.json", "SNG0264.json", "SNG0264.json", "SNG0264.json", "SNG0264.json", "SNG0264.json", "SNG0264.json", "SNG01823.json", "SNG01823.json", "SNG01823.json", "SNG01823.json", "SNG01823.json", "SNG01552.json", "SNG01552.json", "SNG01552.json", "SNG1355.json", "SNG1355.json", "SNG1355.json", "SNG1336.json", "SNG1336.json", "SNG1336.json", "SSNG0207.json", "SSNG0207.json", "SSNG0207.json", "SSNG0207.json", "SSNG0207.json", "SSNG0207.json", "SNG0066.json", "SNG0066.json", "SNG0066.json", "SNG0066.json", "SSNG0032.json", "SSNG0032.json", "SSNG0032.json", "SSNG0032.json", "SSNG0032.json", "SSNG0032.json", "SNG0777.json", "SNG0777.json", "SNG0777.json", "SNG0777.json", "SNG1129.json", "SNG1129.json", "SNG1129.json", "SNG1129.json"]

File diff suppressed because it is too large

File diff suppressed because it is too large

@ -2812,8 +2812,7 @@
"belief_states": [ "belief_states": [
"destination = cambridge", "destination = cambridge",
"day = saturday", "day = saturday",
"departure = london kings cross", "departure = london kings cross"
"people = none"
] ]
}, },
{ {
@ -2831,8 +2830,7 @@
"destination = cambridge", "destination = cambridge",
"day = saturday", "day = saturday",
"arrive = 14:00", "arrive = 14:00",
"departure = london kings cross", "departure = london kings cross"
"people = none"
] ]
}, },
{ {
@ -4167,8 +4165,7 @@
], ],
"belief_states": [ "belief_states": [
"destination = cambridge", "destination = cambridge",
"arrive = 20:30", "arrive = 20:30"
"people = none"
] ]
}, },
{ {
@ -4183,8 +4180,7 @@
"belief_states": [ "belief_states": [
"destination = cambridge", "destination = cambridge",
"day = tuesday", "day = tuesday",
"arrive = 20:30", "arrive = 20:30"
"people = none"
] ]
}, },
{ {
@ -4202,8 +4198,7 @@
"destination = cambridge", "destination = cambridge",
"day = tuesday", "day = tuesday",
"arrive = 20:30", "arrive = 20:30",
"departure = leicester", "departure = leicester"
"people = none"
] ]
}, },
{ {
@ -4224,8 +4219,7 @@
"destination = cambridge", "destination = cambridge",
"day = tuesday", "day = tuesday",
"arrive = 20:30", "arrive = 20:30",
"departure = leicester", "departure = leicester"
"people = none"
] ]
}, },
{ {
@ -4586,9 +4580,7 @@
"belief_states": [ "belief_states": [
"leave = 17:15", "leave = 17:15",
"destination = london liverpool street", "destination = london liverpool street",
"day = friday", "day = friday"
"arrive = none",
"departure = none"
] ]
}, },
{ {
@ -4606,7 +4598,6 @@
"leave = 17:15", "leave = 17:15",
"destination = london liverpool street", "destination = london liverpool street",
"day = friday", "day = friday",
"departure = none",
"people = 8" "people = 8"
] ]
}, },
@ -4627,7 +4618,6 @@
"leave = 17:15", "leave = 17:15",
"destination = london liverpool street", "destination = london liverpool street",
"day = friday", "day = friday",
"departure = none",
"people = 8" "people = 8"
] ]
}, },
@ -5165,8 +5155,7 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"name = kings college", "name = kings college"
"area = none"
] ]
}, },
{ {
@ -5181,8 +5170,7 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"name = kings college", "name = kings college"
"area = none"
] ]
}, },
{ {
@ -5972,7 +5960,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"parking = yes", "parking = yes",
"stars = 0", "stars = 0",
"type = dont care" "type = dont care"
@ -5988,7 +5975,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = east", "area = east",
"parking = yes", "parking = yes",
"stars = 0", "stars = 0",
@ -8069,7 +8055,6 @@
"train" "train"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"day = wednesday", "day = wednesday",
"arrive = 16:00" "arrive = 16:00"
] ]
@ -11236,8 +11221,7 @@
"destination = stevenage", "destination = stevenage",
"day = sunday", "day = sunday",
"arrive = 09:00", "arrive = 09:00",
"departure = cambridge", "departure = cambridge"
"people = none"
] ]
}, },
{ {
@ -13408,8 +13392,7 @@
], ],
"belief_states": [ "belief_states": [
"leave = 08:45", "leave = 08:45",
"destination = london kings cross", "destination = london kings cross"
"departure = none"
] ]
}, },
{ {
@ -13424,8 +13407,7 @@
"belief_states": [ "belief_states": [
"leave = 08:45", "leave = 08:45",
"destination = london kings cross", "destination = london kings cross",
"day = friday", "day = friday"
"departure = none"
] ]
}, },
{ {
@ -13529,8 +13511,7 @@
"belief_states": [ "belief_states": [
"day = thursday", "day = thursday",
"arrive = 09:15", "arrive = 09:15",
"departure = cambridge", "departure = cambridge"
"people = none"
] ]
}, },
{ {
@ -13549,8 +13530,7 @@
"belief_states": [ "belief_states": [
"day = thursday", "day = thursday",
"arrive = 09:15", "arrive = 09:15",
"departure = cambridge", "departure = cambridge"
"people = none"
] ]
}, },
{ {
@ -14487,9 +14467,7 @@
"domains": [ "domains": [
"taxi" "taxi"
], ],
"belief_states": [ "belief_states": []
"destination = none"
]
}, },
{ {
"history": [ "history": [
@ -14504,9 +14482,7 @@
"domains": [ "domains": [
"taxi" "taxi"
], ],
"belief_states": [ "belief_states": []
"destination = none"
]
}, },
{ {
"history": [ "history": [
@ -16637,7 +16613,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = west", "area = west",
"parking = yes", "parking = yes",
"price = cheap", "price = cheap",
@ -16665,7 +16640,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = west", "area = west",
"parking = yes", "parking = yes",
"price = cheap", "price = cheap",
@ -16695,7 +16669,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = west", "area = west",
"parking = yes", "parking = yes",
"price = cheap", "price = cheap",
@ -17318,9 +17291,7 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"type = none", "name = kambar"
"name = kambar",
"area = none"
] ]
}, },
{ {
@ -17333,9 +17304,7 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"type = none", "name = kambar"
"name = kambar",
"area = none"
] ]
}, },
{ {
@ -17350,9 +17319,7 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"type = none", "name = kambar"
"name = kambar",
"area = none"
] ]
}, },
{ {
@ -18976,8 +18943,7 @@
], ],
"belief_states": [ "belief_states": [
"food = bistro", "food = bistro",
"price = moderate", "price = moderate"
"name = none"
] ]
}, },
{ {
@ -18991,8 +18957,7 @@
], ],
"belief_states": [ "belief_states": [
"food = british", "food = british",
"price = moderate", "price = moderate"
"name = none"
] ]
}, },
{ {
@ -19009,7 +18974,6 @@
"belief_states": [ "belief_states": [
"food = british", "food = british",
"price = moderate", "price = moderate",
"name = none",
"area = dont care", "area = dont care",
"day = thursday", "day = thursday",
"people = 2", "people = 2",
@ -21311,7 +21275,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"price = expensive", "price = expensive",
"stars = 4", "stars = 4",
"internet = yes", "internet = yes",
@ -25059,7 +25022,6 @@
"train" "train"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"arrive = 18:00", "arrive = 18:00",
"departure = cambridge" "departure = cambridge"
] ]
@ -28264,7 +28226,6 @@
"taxi" "taxi"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"destination = london kings cross train station", "destination = london kings cross train station",
"departure = museum of classical archaeology" "departure = museum of classical archaeology"
] ]
@ -28605,8 +28566,7 @@
"destination = stevenage", "destination = stevenage",
"day = saturday", "day = saturday",
"arrive = 12:45", "arrive = 12:45",
"departure = cambridge", "departure = cambridge"
"people = none"
] ]
}, },
{ {
@ -28626,8 +28586,7 @@
"destination = stevenage", "destination = stevenage",
"day = saturday", "day = saturday",
"arrive = 12:45", "arrive = 12:45",
"departure = cambridge", "departure = cambridge"
"people = none"
] ]
}, },
{ {
@ -28649,8 +28608,7 @@
"destination = stevenage", "destination = stevenage",
"day = saturday", "day = saturday",
"arrive = 12:45", "arrive = 12:45",
"departure = cambridge", "departure = cambridge"
"people = none"
] ]
}, },
{ {
@ -28674,8 +28632,7 @@
"destination = stevenage", "destination = stevenage",
"day = saturday", "day = saturday",
"arrive = 12:45", "arrive = 12:45",
"departure = cambridge", "departure = cambridge"
"people = none"
] ]
}, },
{ {
@ -29573,7 +29530,6 @@
"belief_states": [ "belief_states": [
"price = cheap", "price = cheap",
"internet = yes", "internet = yes",
"type = none",
"day = friday", "day = friday",
"people = 8", "people = 8",
"stay = 5" "stay = 5"
@ -29598,7 +29554,6 @@
"area = east", "area = east",
"price = cheap", "price = cheap",
"internet = yes", "internet = yes",
"type = none",
"day = friday", "day = friday",
"people = 8", "people = 8",
"stay = 5" "stay = 5"
@ -29626,7 +29581,6 @@
"area = east", "area = east",
"price = cheap", "price = cheap",
"internet = yes", "internet = yes",
"type = none",
"day = friday", "day = friday",
"people = 8", "people = 8",
"stay = 5" "stay = 5"
@ -29656,7 +29610,6 @@
"area = east", "area = east",
"price = cheap", "price = cheap",
"internet = yes", "internet = yes",
"type = none",
"day = friday", "day = friday",
"people = 8", "people = 8",
"stay = 5" "stay = 5"
@ -29688,7 +29641,6 @@
"area = east", "area = east",
"price = cheap", "price = cheap",
"internet = yes", "internet = yes",
"type = none",
"day = friday", "day = friday",
"people = 8", "people = 8",
"stay = 5" "stay = 5"
@ -29722,7 +29674,6 @@
"area = east", "area = east",
"price = cheap", "price = cheap",
"internet = yes", "internet = yes",
"type = none",
"day = friday", "day = friday",
"people = 8", "people = 8",
"stay = 5" "stay = 5"
@ -30615,7 +30566,6 @@
], ],
"belief_states": [ "belief_states": [
"name = limehouse", "name = limehouse",
"price = none",
"day = thursday", "day = thursday",
"people = 6", "people = 6",
"stay = 2" "stay = 2"
@ -30636,7 +30586,6 @@
], ],
"belief_states": [ "belief_states": [
"area = north", "area = north",
"price = none",
"day = thursday", "day = thursday",
"people = 6", "people = 6",
"stay = 2" "stay = 2"
@ -30659,7 +30608,6 @@
], ],
"belief_states": [ "belief_states": [
"area = north", "area = north",
"price = none",
"day = thursday", "day = thursday",
"people = 6", "people = 6",
"stay = 2" "stay = 2"
@ -30684,7 +30632,6 @@
], ],
"belief_states": [ "belief_states": [
"area = north", "area = north",
"price = none",
"day = thursday", "day = thursday",
"people = 6", "people = 6",
"stay = 2" "stay = 2"
@ -30711,7 +30658,6 @@
], ],
"belief_states": [ "belief_states": [
"area = north", "area = north",
"price = none",
"day = thursday", "day = thursday",
"people = 6", "people = 6",
"stay = 2" "stay = 2"
@ -30844,7 +30790,6 @@
], ],
"belief_states": [ "belief_states": [
"destination = la margherita", "destination = la margherita",
"departure = none",
"arrive = 13:45" "arrive = 13:45"
] ]
}, },
@ -30861,7 +30806,6 @@
], ],
"belief_states": [ "belief_states": [
"destination = la margherita", "destination = la margherita",
"departure = none",
"arrive = 13:45" "arrive = 13:45"
] ]
}, },
@ -30927,7 +30871,6 @@
"train" "train"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"destination = london kings cross", "destination = london kings cross",
"day = wednesday", "day = wednesday",
"arrive = 18:00" "arrive = 18:00"
@ -33227,10 +33170,7 @@
"area = centre", "area = centre",
"stars = 0", "stars = 0",
"internet = yes", "internet = yes",
"type = dont care", "type = dont care"
"day = none",
"people = none",
"stay = none"
] ]
}, },
{ {
@ -33251,10 +33191,7 @@
"area = centre", "area = centre",
"stars = 0", "stars = 0",
"internet = yes", "internet = yes",
"type = dont care", "type = dont care"
"day = none",
"people = none",
"stay = none"
] ]
}, },
{ {
@ -33276,10 +33213,7 @@
"area = centre", "area = centre",
"stars = 0", "stars = 0",
"internet = yes", "internet = yes",
"type = dont care", "type = dont care"
"day = none",
"people = none",
"stay = none"
] ]
}, },
{ {
@ -33859,7 +33793,6 @@
], ],
"belief_states": [ "belief_states": [
"leave = 01:45", "leave = 01:45",
"destination = none",
"departure = rice boat" "departure = rice boat"
] ]
}, },
@ -34501,8 +34434,7 @@
], ],
"belief_states": [ "belief_states": [
"day = saturday", "day = saturday",
"departure = stansted airport", "departure = stansted airport"
"people = none"
] ]
}, },
{ {
@ -34518,8 +34450,7 @@
"leave = 17:00", "leave = 17:00",
"destination = cambridge", "destination = cambridge",
"day = saturday", "day = saturday",
"departure = stansted airport", "departure = stansted airport"
"people = none"
] ]
}, },
{ {
@ -34537,8 +34468,7 @@
"leave = 17:00", "leave = 17:00",
"destination = cambridge", "destination = cambridge",
"day = saturday", "day = saturday",
"departure = stansted airport", "departure = stansted airport"
"people = none"
] ]
}, },
{ {
@ -34558,8 +34488,7 @@
"leave = 17:00", "leave = 17:00",
"destination = cambridge", "destination = cambridge",
"day = saturday", "day = saturday",
"departure = stansted airport", "departure = stansted airport"
"people = none"
] ]
}, },
{ {
@ -34581,8 +34510,7 @@
"leave = 17:00", "leave = 17:00",
"destination = cambridge", "destination = cambridge",
"day = saturday", "day = saturday",
"departure = stansted airport", "departure = stansted airport"
"people = none"
] ]
}, },
{ {
@ -34606,8 +34534,7 @@
"leave = 17:00", "leave = 17:00",
"destination = cambridge", "destination = cambridge",
"day = saturday", "day = saturday",
"departure = stansted airport", "departure = stansted airport"
"people = none"
] ]
}, },
{ {
@ -35529,7 +35456,6 @@
], ],
"belief_states": [ "belief_states": [
"price = expensive", "price = expensive",
"name = none",
"area = centre" "area = centre"
] ]
}, },
@ -36077,7 +36003,6 @@
"taxi" "taxi"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"departure = thanh binh" "departure = thanh binh"
] ]
}, },
@ -36091,7 +36016,6 @@
"taxi" "taxi"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"destination = prezzo", "destination = prezzo",
"departure = thanh binh" "departure = thanh binh"
] ]
@ -36108,7 +36032,6 @@
"taxi" "taxi"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"destination = prezzo", "destination = prezzo",
"departure = thanh binh" "departure = thanh binh"
] ]
@ -39313,8 +39236,7 @@
"restaurant" "restaurant"
], ],
"belief_states": [ "belief_states": [
"food = european", "food = european"
"price = none"
] ]
}, },
{ {
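
Every hunk in these data files applies the same cleanup: belief-state entries whose value is "none" are dropped, and turns left with no informative slots end up with an empty list. A minimal sketch of that filter, assuming the JSON layout shown above; the function name and file paths are illustrative:

import json

def drop_none_slots(dataset_items):
    # keep only belief-state entries whose value is not "none"
    for item in dataset_items:
        item['belief_states'] = [
            state for state in item['belief_states']
            if state.split('=', 1)[-1].strip().lower() != 'none'
        ]
    return dataset_items

# placeholder paths for one of the prompt-training files
items = json.load(open('train.json'))
json.dump(drop_none_slots(items), open('train.cleaned.json', 'w'), indent=2)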

@ -1666,7 +1666,6 @@
], ],
"belief_states": [ "belief_states": [
"parking = yes", "parking = yes",
"price = none",
"stars = 4", "stars = 4",
"internet = yes" "internet = yes"
] ]
@ -1691,7 +1690,6 @@
"belief_states": [ "belief_states": [
"name = huntingdon marriott hotel", "name = huntingdon marriott hotel",
"parking = yes", "parking = yes",
"price = none",
"stars = 4", "stars = 4",
"internet = yes" "internet = yes"
] ]
@ -1718,7 +1716,6 @@
"belief_states": [ "belief_states": [
"name = huntingdon marriott hotel", "name = huntingdon marriott hotel",
"parking = yes", "parking = yes",
"price = none",
"stars = 4", "stars = 4",
"internet = yes" "internet = yes"
] ]
@ -1747,7 +1744,6 @@
"belief_states": [ "belief_states": [
"name = huntingdon marriott hotel", "name = huntingdon marriott hotel",
"parking = yes", "parking = yes",
"price = none",
"stars = 4", "stars = 4",
"internet = yes" "internet = yes"
] ]
@ -3024,7 +3020,6 @@
"train" "train"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"destination = london liverpool street", "destination = london liverpool street",
"departure = cambridge" "departure = cambridge"
] ]
@ -3045,7 +3040,6 @@
"train" "train"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"destination = london liverpool street", "destination = london liverpool street",
"departure = cambridge" "departure = cambridge"
] ]
@ -3252,7 +3246,6 @@
"belief_states": [ "belief_states": [
"area = north", "area = north",
"parking = yes", "parking = yes",
"price = none",
"stars = 4" "stars = 4"
] ]
}, },
@ -3270,7 +3263,6 @@
"belief_states": [ "belief_states": [
"area = north", "area = north",
"parking = yes", "parking = yes",
"price = none",
"stars = 4", "stars = 4",
"day = friday", "day = friday",
"people = 4", "people = 4",
@ -3293,7 +3285,6 @@
"belief_states": [ "belief_states": [
"area = north", "area = north",
"parking = yes", "parking = yes",
"price = none",
"stars = 4", "stars = 4",
"day = friday", "day = friday",
"people = 4", "people = 4",
@ -3318,7 +3309,6 @@
"belief_states": [ "belief_states": [
"area = north", "area = north",
"parking = yes", "parking = yes",
"price = none",
"stars = 4", "stars = 4",
"day = friday", "day = friday",
"people = 7", "people = 7",
@ -3345,7 +3335,6 @@
"belief_states": [ "belief_states": [
"area = north", "area = north",
"parking = yes", "parking = yes",
"price = none",
"stars = 4", "stars = 4",
"day = friday", "day = friday",
"people = 7", "people = 7",
@ -5334,8 +5323,7 @@
], ],
"belief_states": [ "belief_states": [
"destination = cambridge", "destination = cambridge",
"arrive = 20:30", "arrive = 20:30"
"people = none"
] ]
}, },
{ {
@ -5350,8 +5338,7 @@
"belief_states": [ "belief_states": [
"destination = cambridge", "destination = cambridge",
"day = tuesday", "day = tuesday",
"arrive = 20:30", "arrive = 20:30"
"people = none"
] ]
}, },
{ {
@ -5369,8 +5356,7 @@
"destination = cambridge", "destination = cambridge",
"day = tuesday", "day = tuesday",
"arrive = 20:30", "arrive = 20:30",
"departure = leicester", "departure = leicester"
"people = none"
] ]
}, },
{ {
@ -5391,8 +5377,7 @@
"destination = cambridge", "destination = cambridge",
"day = tuesday", "day = tuesday",
"arrive = 20:30", "arrive = 20:30",
"departure = leicester", "departure = leicester"
"people = none"
] ]
}, },
{ {
@ -6795,8 +6780,7 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"name = kings college", "name = kings college"
"area = none"
] ]
}, },
{ {
@ -6811,8 +6795,7 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"name = kings college", "name = kings college"
"area = none"
] ]
}, },
{ {
@ -7708,8 +7691,7 @@
"leave = 17:15", "leave = 17:15",
"destination = london kings cross", "destination = london kings cross",
"day = monday", "day = monday",
"departure = cambridge", "departure = cambridge"
"people = none"
] ]
}, },
{ {
@ -8379,13 +8361,7 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = arbury lodge guesthouse", "name = arbury lodge guesthouse"
"area = none",
"parking = none",
"price = none",
"stars = none",
"internet = none",
"type = none"
] ]
}, },
{ {
@ -8399,12 +8375,6 @@
], ],
"belief_states": [ "belief_states": [
"name = arbury lodge guesthouse", "name = arbury lodge guesthouse",
"area = none",
"parking = none",
"price = none",
"stars = none",
"internet = none",
"type = none",
"day = friday", "day = friday",
"people = 4", "people = 4",
"stay = 2" "stay = 2"
@ -8423,12 +8393,6 @@
], ],
"belief_states": [ "belief_states": [
"name = arbury lodge guesthouse", "name = arbury lodge guesthouse",
"area = none",
"parking = none",
"price = none",
"stars = none",
"internet = none",
"type = none",
"day = friday", "day = friday",
"people = 4", "people = 4",
"stay = 1" "stay = 1"
@ -8449,12 +8413,6 @@
], ],
"belief_states": [ "belief_states": [
"name = arbury lodge guesthouse", "name = arbury lodge guesthouse",
"area = none",
"parking = none",
"price = none",
"stars = none",
"internet = none",
"type = none",
"day = friday", "day = friday",
"people = 4", "people = 4",
"stay = 1" "stay = 1"
@ -9413,7 +9371,6 @@
"train" "train"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"day = wednesday", "day = wednesday",
"arrive = 16:00" "arrive = 16:00"
] ]
@ -11470,7 +11427,6 @@
"restaurant" "restaurant"
], ],
"belief_states": [ "belief_states": [
"price = none",
"name = dont care", "name = dont care",
"area = centre", "area = centre",
"day = sunday", "day = sunday",
@ -11495,7 +11451,6 @@
], ],
"belief_states": [ "belief_states": [
"food = dont care", "food = dont care",
"price = none",
"name = dont care", "name = dont care",
"area = centre", "area = centre",
"day = sunday", "day = sunday",
@ -11522,7 +11477,6 @@
], ],
"belief_states": [ "belief_states": [
"food = dont care", "food = dont care",
"price = none",
"name = dont care", "name = dont care",
"area = centre", "area = centre",
"day = sunday", "day = sunday",
@ -15335,8 +15289,7 @@
"belief_states": [ "belief_states": [
"day = thursday", "day = thursday",
"arrive = 09:15", "arrive = 09:15",
"departure = cambridge", "departure = cambridge"
"people = none"
] ]
}, },
{ {
@ -15355,8 +15308,7 @@
"belief_states": [ "belief_states": [
"day = thursday", "day = thursday",
"arrive = 09:15", "arrive = 09:15",
"departure = cambridge", "departure = cambridge"
"people = none"
] ]
}, },
{ {
@ -15528,7 +15480,6 @@
"train" "train"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"destination = birmingham new street", "destination = birmingham new street",
"day = thursday", "day = thursday",
"arrive = 17:15", "arrive = 17:15",
@ -15543,7 +15494,6 @@
"train" "train"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"day = monday", "day = monday",
"departure = stevenage" "departure = stevenage"
] ]
@ -16374,7 +16324,6 @@
"restaurant" "restaurant"
], ],
"belief_states": [ "belief_states": [
"price = none",
"name = chiquito restaurant bar", "name = chiquito restaurant bar",
"day = tuesday", "day = tuesday",
"people = 6", "people = 6",
@ -16521,7 +16470,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = north", "area = north",
"price = cheap" "price = cheap"
] ]
@ -16536,7 +16484,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = north", "area = north",
"price = cheap", "price = cheap",
"stars = 4", "stars = 4",
@ -16634,7 +16581,6 @@
"belief_states": [ "belief_states": [
"area = north", "area = north",
"price = cheap", "price = cheap",
"stars = none",
"day = thursday", "day = thursday",
"people = 5", "people = 5",
"stay = 5" "stay = 5"
@ -16662,7 +16608,6 @@
"belief_states": [ "belief_states": [
"area = north", "area = north",
"price = cheap", "price = cheap",
"stars = none",
"day = thursday", "day = thursday",
"people = 5", "people = 5",
"stay = 5" "stay = 5"
@ -16880,7 +16825,6 @@
], ],
"belief_states": [ "belief_states": [
"food = european", "food = european",
"name = none",
"area = centre", "area = centre",
"day = wednesday", "day = wednesday",
"people = 5", "people = 5",
@ -19246,7 +19190,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = north", "area = north",
"stars = 4" "stars = 4"
] ]
@ -19263,7 +19206,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = north", "area = north",
"stars = 4", "stars = 4",
"day = thursday", "day = thursday",
@ -19286,7 +19228,6 @@
], ],
"belief_states": [ "belief_states": [
"area = north", "area = north",
"price = none",
"stars = 4", "stars = 4",
"day = thursday", "day = thursday",
"people = 3", "people = 3",
@ -19309,7 +19250,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = north", "area = north",
"stars = 4", "stars = 4",
"day = thursday", "day = thursday",
@ -19335,7 +19275,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = north", "area = north",
"stars = 4", "stars = 4",
"day = thursday", "day = thursday",
@ -19363,7 +19302,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = north", "area = north",
"stars = 4", "stars = 4",
"day = thursday", "day = thursday",
@ -20371,9 +20309,7 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"type = none", "name = kambar"
"name = kambar",
"area = none"
] ]
}, },
{ {
@ -20386,9 +20322,7 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"type = none", "name = kambar"
"name = kambar",
"area = none"
] ]
}, },
{ {
@ -20403,9 +20337,7 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"type = none", "name = kambar"
"name = kambar",
"area = none"
] ]
}, },
{ {
@ -23468,7 +23400,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"price = moderate", "price = moderate",
"internet = yes" "internet = yes"
] ]
@ -24059,7 +23990,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = north", "area = north",
"parking = yes", "parking = yes",
"price = cheap", "price = cheap",
@ -24084,7 +24014,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = north", "area = north",
"parking = yes", "parking = yes",
"price = cheap", "price = cheap",
@ -25223,7 +25152,6 @@
"belief_states": [ "belief_states": [
"food = chinese", "food = chinese",
"price = cheap", "price = cheap",
"name = none",
"area = centre", "area = centre",
"day = wednesday", "day = wednesday",
"people = 5", "people = 5",
@ -25250,8 +25178,6 @@
"belief_states": [ "belief_states": [
"food = chinese", "food = chinese",
"price = cheap", "price = cheap",
"name = none",
"area = none",
"day = wednesday", "day = wednesday",
"people = 5", "people = 5",
"time = 19:30" "time = 19:30"
@ -25279,7 +25205,6 @@
"belief_states": [ "belief_states": [
"food = asian oriental", "food = asian oriental",
"price = cheap", "price = cheap",
"area = none",
"day = wednesday", "day = wednesday",
"people = 5", "people = 5",
"time = 19:30" "time = 19:30"
@ -25310,7 +25235,6 @@
"food = asian oriental", "food = asian oriental",
"price = cheap", "price = cheap",
"name = dojo noodle bar", "name = dojo noodle bar",
"area = none",
"day = wednesday", "day = wednesday",
"people = 5", "people = 5",
"time = 19:30" "time = 19:30"
@ -25343,7 +25267,6 @@
"food = asian oriental", "food = asian oriental",
"price = cheap", "price = cheap",
"name = dojo noodle bar", "name = dojo noodle bar",
"area = none",
"day = wednesday", "day = wednesday",
"people = 5", "people = 5",
"time = 19:30" "time = 19:30"
@ -26242,7 +26165,6 @@
], ],
"belief_states": [ "belief_states": [
"leave = 14:00", "leave = 14:00",
"destination = none",
"departure = la margherita" "departure = la margherita"
] ]
}, },
@ -26257,9 +26179,7 @@
], ],
"belief_states": [ "belief_states": [
"leave = 14:00", "leave = 14:00",
"destination = none", "departure = la margherita"
"departure = la margherita",
"arrive = none"
] ]
}, },
{ {
@ -26275,9 +26195,7 @@
], ],
"belief_states": [ "belief_states": [
"leave = 14:00", "leave = 14:00",
"destination = none", "departure = la margherita"
"departure = la margherita",
"arrive = none"
] ]
}, },
{ {
@ -26296,8 +26214,7 @@
"belief_states": [ "belief_states": [
"leave = 14:00", "leave = 14:00",
"destination = la margherita", "destination = la margherita",
"departure = avalon", "departure = avalon"
"arrive = none"
] ]
}, },
{ {
@ -27805,8 +27722,7 @@
], ],
"belief_states": [ "belief_states": [
"leave = 01:45", "leave = 01:45",
"destination = da vinci pizzeria", "destination = da vinci pizzeria"
"arrive = none"
] ]
}, },
{ {
@ -28114,7 +28030,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"price = cheap", "price = cheap",
"day = tuesday", "day = tuesday",
"people = 6", "people = 6",
@ -28135,7 +28050,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = dont care", "area = dont care",
"price = cheap", "price = cheap",
"day = tuesday", "day = tuesday",
@ -28159,7 +28073,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = dont care", "area = dont care",
"price = cheap", "price = cheap",
"day = tuesday", "day = tuesday",
@ -28185,7 +28098,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = dont care", "area = dont care",
"price = cheap", "price = cheap",
"day = tuesday", "day = tuesday",
@ -28847,7 +28759,6 @@
], ],
"belief_states": [ "belief_states": [
"area = dont care", "area = dont care",
"parking = none",
"price = moderate", "price = moderate",
"stars = 4", "stars = 4",
"internet = yes", "internet = yes",
@ -28876,7 +28787,6 @@
], ],
"belief_states": [ "belief_states": [
"area = dont care", "area = dont care",
"parking = none",
"price = moderate", "price = moderate",
"stars = 4", "stars = 4",
"internet = yes", "internet = yes",
@ -28907,7 +28817,6 @@
], ],
"belief_states": [ "belief_states": [
"area = dont care", "area = dont care",
"parking = none",
"price = moderate", "price = moderate",
"stars = 4", "stars = 4",
"internet = yes", "internet = yes",
@ -31526,9 +31435,7 @@
], ],
"belief_states": [ "belief_states": [
"leave = 19:16", "leave = 19:16",
"destination = none", "day = tuesday"
"day = tuesday",
"departure = none"
] ]
}, },
{ {
@ -31616,9 +31523,7 @@
"restaurant" "restaurant"
], ],
"belief_states": [ "belief_states": [
"food = none",
"price = expensive", "price = expensive",
"name = none",
"area = centre", "area = centre",
"day = monday", "day = monday",
"people = 7", "people = 7",
@ -32277,7 +32182,6 @@
"train" "train"
], ],
"belief_states": [ "belief_states": [
"destination = none",
"day = monday", "day = monday",
"departure = ely" "departure = ely"
] ]
@ -32853,9 +32757,7 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"type = none", "name = wandlebury country park"
"name = wandlebury country park",
"area = none"
] ]
}, },
{ {
@ -32870,9 +32772,7 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"type = none", "name = wandlebury country park"
"name = wandlebury country park",
"area = none"
] ]
}, },
{ {
@ -35981,7 +35881,6 @@
"train" "train"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"day = wednesday", "day = wednesday",
"arrive = 09:30" "arrive = 09:30"
] ]
@ -37841,8 +37740,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"parking = none",
"stars = none",
"type = hotel", "type = hotel",
"day = sunday", "day = sunday",
"people = 4", "people = 4",
@ -37866,8 +37763,6 @@
], ],
"belief_states": [ "belief_states": [
"name = city centre north b and b", "name = city centre north b and b",
"parking = none",
"stars = none",
"type = hotel", "type = hotel",
"day = sunday", "day = sunday",
"people = 4", "people = 4",
@ -37893,8 +37788,6 @@
], ],
"belief_states": [ "belief_states": [
"name = city centre north b and b", "name = city centre north b and b",
"parking = none",
"stars = none",
"type = hotel", "type = hotel",
"day = sunday", "day = sunday",
"people = 4", "people = 4",
@ -37922,8 +37815,6 @@
], ],
"belief_states": [ "belief_states": [
"name = city centre north b and b", "name = city centre north b and b",
"parking = none",
"stars = none",
"type = hotel", "type = hotel",
"day = sunday", "day = sunday",
"people = 4", "people = 4",
@ -38273,7 +38164,6 @@
"train" "train"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"destination = london kings cross", "destination = london kings cross",
"day = wednesday", "day = wednesday",
"arrive = 18:00" "arrive = 18:00"
@ -40012,7 +39902,6 @@
], ],
"belief_states": [ "belief_states": [
"name = warkworth house", "name = warkworth house",
"area = none",
"stars = 4", "stars = 4",
"day = wednesday", "day = wednesday",
"people = 4", "people = 4",
@ -40031,8 +39920,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = none",
"price = moderate", "price = moderate",
"stars = 4", "stars = 4",
"day = wednesday", "day = wednesday",
@ -40054,7 +39941,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"area = none",
"price = moderate", "price = moderate",
"stars = 4", "stars = 4",
"day = wednesday", "day = wednesday",
@ -40078,7 +39964,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"area = none",
"parking = dont care", "parking = dont care",
"price = moderate", "price = moderate",
"stars = 4", "stars = 4",
@ -40105,7 +39990,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"area = none",
"parking = dont care", "parking = dont care",
"price = moderate", "price = moderate",
"stars = 4", "stars = 4",
@ -41515,7 +41399,6 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"type = none",
"name = clare college", "name = clare college",
"area = west" "area = west"
] ]
@ -41543,8 +41426,7 @@
], ],
"belief_states": [ "belief_states": [
"food = asian oriental", "food = asian oriental",
"price = moderate", "price = moderate"
"area = none"
] ]
}, },
{ {
@ -41562,7 +41444,6 @@
"food = asian oriental", "food = asian oriental",
"price = moderate", "price = moderate",
"name = yippee noodle bar", "name = yippee noodle bar",
"area = none",
"day = sunday", "day = sunday",
"people = 4", "people = 4",
"time = 11:45" "time = 11:45"
@ -41583,8 +41464,6 @@
], ],
"belief_states": [ "belief_states": [
"price = moderate", "price = moderate",
"name = none",
"area = none",
"day = sunday", "day = sunday",
"people = 4", "people = 4",
"time = 11:45" "time = 11:45"
@ -41607,8 +41486,6 @@
], ],
"belief_states": [ "belief_states": [
"price = moderate", "price = moderate",
"name = none",
"area = none",
"day = sunday", "day = sunday",
"people = 4", "people = 4",
"time = 11:45" "time = 11:45"
@ -41633,8 +41510,6 @@
], ],
"belief_states": [ "belief_states": [
"price = moderate", "price = moderate",
"name = none",
"area = none",
"day = sunday", "day = sunday",
"people = 4", "people = 4",
"time = 11:45" "time = 11:45"
@ -42213,8 +42088,7 @@
], ],
"belief_states": [ "belief_states": [
"day = saturday", "day = saturday",
"departure = stansted airport", "departure = stansted airport"
"people = none"
] ]
}, },
{ {
@ -42230,8 +42104,7 @@
"leave = 17:00", "leave = 17:00",
"destination = cambridge", "destination = cambridge",
"day = saturday", "day = saturday",
"departure = stansted airport", "departure = stansted airport"
"people = none"
] ]
}, },
{ {
@ -42249,8 +42122,7 @@
"leave = 17:00", "leave = 17:00",
"destination = cambridge", "destination = cambridge",
"day = saturday", "day = saturday",
"departure = stansted airport", "departure = stansted airport"
"people = none"
] ]
}, },
{ {
@ -42270,8 +42142,7 @@
"leave = 17:00", "leave = 17:00",
"destination = cambridge", "destination = cambridge",
"day = saturday", "day = saturday",
"departure = stansted airport", "departure = stansted airport"
"people = none"
] ]
}, },
{ {
@ -42293,8 +42164,7 @@
"leave = 17:00", "leave = 17:00",
"destination = cambridge", "destination = cambridge",
"day = saturday", "day = saturday",
"departure = stansted airport", "departure = stansted airport"
"people = none"
] ]
}, },
{ {
@ -42318,8 +42188,7 @@
"leave = 17:00", "leave = 17:00",
"destination = cambridge", "destination = cambridge",
"day = saturday", "day = saturday",
"departure = stansted airport", "departure = stansted airport"
"people = none"
] ]
}, },
{ {
@ -42564,7 +42433,6 @@
"restaurant" "restaurant"
], ],
"belief_states": [ "belief_states": [
"food = none",
"price = expensive", "price = expensive",
"area = west" "area = west"
] ]
@ -43460,7 +43328,6 @@
], ],
"belief_states": [ "belief_states": [
"food = indian", "food = indian",
"name = none",
"area = west" "area = west"
] ]
}, },
@ -43477,7 +43344,6 @@
], ],
"belief_states": [ "belief_states": [
"food = indian", "food = indian",
"name = none",
"area = west" "area = west"
] ]
}, },
@ -43496,7 +43362,6 @@
], ],
"belief_states": [ "belief_states": [
"food = indian", "food = indian",
"name = none",
"area = west" "area = west"
] ]
}, },
@ -43824,7 +43689,6 @@
], ],
"belief_states": [ "belief_states": [
"destination = norwich", "destination = norwich",
"day = none",
"departure = cambridge" "departure = cambridge"
] ]
}, },
@ -44343,7 +44207,6 @@
"taxi" "taxi"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"departure = thanh binh" "departure = thanh binh"
] ]
}, },
@ -44357,7 +44220,6 @@
"taxi" "taxi"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"destination = prezzo", "destination = prezzo",
"departure = thanh binh" "departure = thanh binh"
] ]
@ -44374,7 +44236,6 @@
"taxi" "taxi"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"destination = prezzo", "destination = prezzo",
"departure = thanh binh" "departure = thanh binh"
] ]
@ -44613,7 +44474,6 @@
], ],
"belief_states": [ "belief_states": [
"price = cheap", "price = cheap",
"name = none",
"area = centre" "area = centre"
] ]
}, },
@ -44632,7 +44492,6 @@
], ],
"belief_states": [ "belief_states": [
"price = cheap", "price = cheap",
"name = none",
"area = centre" "area = centre"
] ]
}, },
@ -45938,15 +45797,7 @@
"domains": [ "domains": [
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": []
"parking = none",
"price = none",
"stars = none",
"internet = none",
"day = none",
"people = none",
"stay = none"
]
}, },
{ {
"history": [ "history": [
@ -45963,15 +45814,7 @@
"domains": [ "domains": [
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": []
"parking = none",
"price = none",
"stars = none",
"internet = none",
"day = none",
"people = none",
"stay = none"
]
}, },
{ {
"history": [ "history": [
@ -48155,7 +47998,6 @@
"restaurant" "restaurant"
], ],
"belief_states": [ "belief_states": [
"food = none",
"name = tang chinese" "name = tang chinese"
] ]
}, },
@ -48169,7 +48011,6 @@
"restaurant" "restaurant"
], ],
"belief_states": [ "belief_states": [
"food = none",
"name = tang chinese", "name = tang chinese",
"day = friday", "day = friday",
"people = 4", "people = 4",
@ -48189,7 +48030,6 @@
], ],
"belief_states": [ "belief_states": [
"food = dont care", "food = dont care",
"price = none",
"area = centre", "area = centre",
"day = friday", "day = friday",
"people = 4", "people = 4",
@ -48211,7 +48051,6 @@
], ],
"belief_states": [ "belief_states": [
"food = dont care", "food = dont care",
"price = none",
"name = kymmoy", "name = kymmoy",
"area = centre", "area = centre",
"day = friday", "day = friday",
@ -48236,7 +48075,6 @@
], ],
"belief_states": [ "belief_states": [
"food = dont care", "food = dont care",
"price = none",
"name = kymmoy", "name = kymmoy",
"area = centre", "area = centre",
"day = friday", "day = friday",

File diff suppressed because it is too large

File diff suppressed because it is too large

@ -1087,10 +1087,7 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = none",
"parking = yes", "parking = yes",
"price = none",
"type = guesthouse" "type = guesthouse"
] ]
}, },
@ -1105,9 +1102,7 @@
], ],
"belief_states": [ "belief_states": [
"name = acorn guest house", "name = acorn guest house",
"area = none",
"parking = yes", "parking = yes",
"price = none",
"type = guesthouse" "type = guesthouse"
] ]
}, },
@ -1124,9 +1119,7 @@
], ],
"belief_states": [ "belief_states": [
"name = acorn guest house", "name = acorn guest house",
"area = none",
"parking = yes", "parking = yes",
"price = none",
"type = guesthouse" "type = guesthouse"
] ]
}, },
@ -1145,9 +1138,7 @@
], ],
"belief_states": [ "belief_states": [
"name = acorn guest house", "name = acorn guest house",
"area = none",
"parking = yes", "parking = yes",
"price = none",
"type = guesthouse", "type = guesthouse",
"day = friday", "day = friday",
"people = 5", "people = 5",
@ -1171,9 +1162,7 @@
], ],
"belief_states": [ "belief_states": [
"name = acorn guest house", "name = acorn guest house",
"area = none",
"parking = yes", "parking = yes",
"price = none",
"type = guesthouse", "type = guesthouse",
"day = friday", "day = friday",
"people = 5", "people = 5",
@ -1466,8 +1455,7 @@
], ],
"belief_states": [ "belief_states": [
"destination = cambridge", "destination = cambridge",
"arrive = 20:30", "arrive = 20:30"
"people = none"
] ]
}, },
{ {
@ -1482,8 +1470,7 @@
"belief_states": [ "belief_states": [
"destination = cambridge", "destination = cambridge",
"day = tuesday", "day = tuesday",
"arrive = 20:30", "arrive = 20:30"
"people = none"
] ]
}, },
{ {
@ -1501,8 +1488,7 @@
"destination = cambridge", "destination = cambridge",
"day = tuesday", "day = tuesday",
"arrive = 20:30", "arrive = 20:30",
"departure = leicester", "departure = leicester"
"people = none"
] ]
}, },
{ {
@ -1523,8 +1509,7 @@
"destination = cambridge", "destination = cambridge",
"day = tuesday", "day = tuesday",
"arrive = 20:30", "arrive = 20:30",
"departure = leicester", "departure = leicester"
"people = none"
] ]
}, },
{ {
@ -3523,7 +3508,6 @@
"train" "train"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"day = wednesday", "day = wednesday",
"arrive = 16:00" "arrive = 16:00"
] ]
@ -5809,7 +5793,6 @@
"train" "train"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"destination = birmingham new street", "destination = birmingham new street",
"day = thursday", "day = thursday",
"arrive = 17:15", "arrive = 17:15",
@ -7807,8 +7790,7 @@
"destination = cambridge", "destination = cambridge",
"day = saturday", "day = saturday",
"arrive = 21:00", "arrive = 21:00",
"departure = huntingdon", "departure = huntingdon"
"people = none"
] ]
}, },
{ {
@ -7828,8 +7810,7 @@
"destination = cambridge", "destination = cambridge",
"day = saturday", "day = saturday",
"arrive = 21:00", "arrive = 21:00",
"departure = huntingdon", "departure = huntingdon"
"people = none"
] ]
}, },
{ {
@ -7851,8 +7832,7 @@
"destination = cambridge", "destination = cambridge",
"day = saturday", "day = saturday",
"arrive = 21:00", "arrive = 21:00",
"departure = huntingdon", "departure = huntingdon"
"people = none"
] ]
}, },
{ {
@ -8128,7 +8108,6 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"name = none",
"area = west", "area = west",
"type = guesthouse" "type = guesthouse"
] ]
@ -8229,12 +8208,7 @@
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": [
"parking = none", "price = cheap"
"price = cheap",
"type = none",
"day = none",
"people = none",
"stay = none"
] ]
}, },
{ {
@ -8260,11 +8234,7 @@
"name = worth house", "name = worth house",
"parking = yes", "parking = yes",
"price = cheap", "price = cheap",
"stars = dont care", "stars = dont care"
"type = none",
"day = none",
"people = none",
"stay = none"
] ]
}, },
{ {
@ -8292,8 +8262,7 @@
"name = worth house", "name = worth house",
"parking = yes", "parking = yes",
"price = cheap", "price = cheap",
"stars = dont care", "stars = dont care"
"type = none"
] ]
}, },
{ {
@ -8323,8 +8292,7 @@
"name = worth house", "name = worth house",
"parking = yes", "parking = yes",
"price = cheap", "price = cheap",
"stars = dont care", "stars = dont care"
"type = none"
] ]
}, },
{ {
@ -8488,9 +8456,7 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"type = none", "name = kambar"
"name = kambar",
"area = none"
] ]
}, },
{ {
@ -8503,9 +8469,7 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"type = none", "name = kambar"
"name = kambar",
"area = none"
] ]
}, },
{ {
@ -8520,9 +8484,7 @@
"attraction" "attraction"
], ],
"belief_states": [ "belief_states": [
"type = none", "name = kambar"
"name = kambar",
"area = none"
] ]
}, },
{ {
@ -10883,7 +10845,6 @@
], ],
"belief_states": [ "belief_states": [
"leave = 14:00", "leave = 14:00",
"destination = none",
"departure = la margherita" "departure = la margherita"
] ]
}, },
@ -10898,9 +10859,7 @@
], ],
"belief_states": [ "belief_states": [
"leave = 14:00", "leave = 14:00",
"destination = none", "departure = la margherita"
"departure = la margherita",
"arrive = none"
] ]
}, },
{ {
@ -10916,9 +10875,7 @@
], ],
"belief_states": [ "belief_states": [
"leave = 14:00", "leave = 14:00",
"destination = none", "departure = la margherita"
"departure = la margherita",
"arrive = none"
] ]
}, },
{ {
@ -10937,8 +10894,7 @@
"belief_states": [ "belief_states": [
"leave = 14:00", "leave = 14:00",
"destination = la margherita", "destination = la margherita",
"departure = avalon", "departure = avalon"
"arrive = none"
] ]
}, },
{ {
@ -12636,7 +12592,6 @@
"train" "train"
], ],
"belief_states": [ "belief_states": [
"leave = none",
"arrive = 18:00", "arrive = 18:00",
"departure = cambridge" "departure = cambridge"
] ]
@ -13141,7 +13096,6 @@
"belief_states": [ "belief_states": [
"food = indian", "food = indian",
"price = cheap", "price = cheap",
"name = none",
"area = centre", "area = centre",
"day = friday", "day = friday",
"people = 8", "people = 8",
@ -13192,7 +13146,6 @@
"belief_states": [ "belief_states": [
"food = indian", "food = indian",
"price = cheap", "price = cheap",
"name = none",
"area = centre", "area = centre",
"day = friday", "day = friday",
"people = 8", "people = 8",
@ -13221,7 +13174,6 @@
"belief_states": [ "belief_states": [
"food = indian", "food = indian",
"price = cheap", "price = cheap",
"name = none",
"area = centre", "area = centre",
"day = friday", "day = friday",
"people = 8", "people = 8",
@ -13328,7 +13280,6 @@
"train" "train"
], ],
"belief_states": [ "belief_states": [
"destination = none",
"day = monday", "day = monday",
"departure = ely" "departure = ely"
] ]
@ -15244,8 +15195,7 @@
], ],
"belief_states": [ "belief_states": [
"food = portugese", "food = portugese",
"price = cheap", "price = cheap"
"area = none"
] ]
}, },
{ {
@ -18122,15 +18072,7 @@
"domains": [ "domains": [
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": []
"parking = none",
"price = none",
"stars = none",
"internet = none",
"day = none",
"people = none",
"stay = none"
]
}, },
{ {
"history": [ "history": [
@ -18147,15 +18089,7 @@
"domains": [ "domains": [
"hotel" "hotel"
], ],
"belief_states": [ "belief_states": []
"parking = none",
"price = none",
"stars = none",
"internet = none",
"day = none",
"people = none",
"stay = none"
]
}, },
{ {
"history": [ "history": [

File diff suppressed because it is too large

File diff suppressed because it is too large

@ -0,0 +1,637 @@
helpful
place
yeas
perfect
sounds
long
dont'
nope
then
?
hey
doesnt'
yeah
great
good
nice
instead
help
pre
fine
ly
!
,
.
's
goodbye
noise
ah
uh
um
ooh
er
im
iam
ill
id
bye
hm
loo
ne
mmhm
sil
unintelligible
cough
breath
yea
sigh
tvnoise
huh
a
a's
able
about
above
according
accordingly
across
actually
after
afterwards
again
against
ain't
all
allow
allows
almost
alone
along
already
also
although
always
am
among
amongst
an
and
another
any
anybody
anyhow
anyone
anything
anyway
anyways
anywhere
apart
appear
appreciate
appropriate
are
aren't
around
as
aside
ask
asking
associated
at
available
away
awfully
b
be
became
because
become
becomes
becoming
been
before
beforehand
behind
being
believe
below
beside
besides
best
better
between
beyond
both
brief
but
by
c
c'mon
c's
came
can
can't
cannot
cant
cause
causes
certain
certainly
changes
clearly
co
com
come
comes
concerning
consequently
consider
considering
contain
containing
contains
corresponding
could
couldn't
course
currently
d
definitely
described
despite
did
didn't
different
do
does
doesn't
doing
don't
done
down
downwards
during
e
each
edu
eg
eight
either
else
elsewhere
enough
entirely
especially
et
etc
even
ever
every
everybody
everyone
everything
everywhere
ex
exactly
example
except
f
far
few
fifth
first
five
followed
following
follows
for
former
formerly
forth
four
from
further
furthermore
g
get
gets
getting
given
gives
go
goes
going
gone
got
gotten
greetings
h
had
hadn't
happens
hardly
has
hasn't
have
haven't
having
he
he's
hello
help
hence
her
here
here's
hereafter
hereby
herein
hereupon
hers
herself
hi
him
himself
his
hither
hopefully
how
howbeit
however
i
i'd
i'll
i'm
i've
ie
if
ignored
immediate
in
inasmuch
inc
indeed
indicate
indicated
indicates
inner
insofar
instead
into
inward
is
isn't
it
it'd
it'll
it's
its
itself
j
just
k
keep
keeps
kept
know
known
knows
l
last
lately
later
latter
latterly
least
less
lest
let
let's
like
liked
likely
little
look
looking
looks
ltd
m
mainly
many
may
maybe
me
mean
meanwhile
merely
might
more
moreover
most
mostly
much
must
my
myself
n
name
namely
nd
near
nearly
necessary
need
needs
neither
never
nevertheless
new
next
night
nights
nine
no
nobody
non
none
noone
nor
normally
not
nothing
novel
now
nowhere
number
o
obviously
of
off
often
oh
ok
okay
old
on
once
one
ones
only
onto
or
other
others
otherwise
ought
our
ours
ourselves
out
outside
over
overall
own
p
particular
particularly
people
per
perhaps
placed
please
plus
possible
presumably
price
probably
provides
q
que
quite
qv
r
rather
rating
ratings
range
rd
re
really
reasonably
regarding
regardless
regards
relatively
respectively
right
s
said
same
saw
say
saying
says
second
secondly
see
seeing
seem
seemed
seeming
seems
seen
self
selves
sensible
sent
serious
seriously
seven
several
shall
she
should
shouldn't
since
six
so
some
somebody
somehow
someone
something
sometime
sometimes
somewhat
somewhere
soon
sorry
specified
specify
specifying
star
stars
still
sub
such
sup
sure
t
t's
take
taken
tell
tends
th
than
thank
thanks
thanx
that
that's
thats
the
their
theirs
them
themselves
then
thence
there
there's
thereafter
thereby
therefore
therein
theres
thereupon
these
they
they'd
they'll
they're
they've
think
third
this
thorough
thoroughly
those
though
three
through
throughout
thru
thus
to
together
too
took
toward
towards
tried
tries
truly
try
trying
twice
two
u
un
under
unfortunately
unless
unlikely
until
unto
up
upon
us
use
used
useful
uses
using
usually
uucp
v
value
various
very
via
viz
vs
w
want
wants
was
wasn't
way
we
we'd
we'll
we're
we've
welcome
well
went
were
weren't
what
what's
whatever
when
whence
whenever
where
where's
whereafter
whereas
whereby
wherein
whereupon
wherever
whether
which
while
whither
who
who's
whoever
whole
whom
whose
why
will
willing
wish
with
within
without
won't
wonder
would
wouldn't
x
y
yes
yet
you
you'd
you'll
you're
you've
your
yours
yourself
yourselves
z
zero
hotel
train
restaurant
attraction
taxi
book
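
This added file is a stopword list: filler and noise tokens at the top, a standard English stopword inventory, and the MultiWOZ domain names plus "book" at the end. A minimal sketch of how such a file is typically consumed; the path is a placeholder:

with open('stopwords.txt') as f:  # placeholder path
    stopwords = set(line.strip() for line in f if line.strip())

utterance = "i am looking for a cheap hotel in the north"
content_words = [w for w in utterance.split() if w not in stopwords]
print(content_words)  # ['cheap', 'north'] - 'hotel' is filtered as a domain word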

@ -0,0 +1,77 @@
import json
import os
import numpy as np
class PromptDstDataset:
def __init__(self, file_path, shuffle=True):
# Assertion check for file availability
assert os.path.isfile(file_path)
# add all processed data items to this list
self.dataset_items = []
self.total_num_slot_value_pairs = 0
print('Loading the dataset from :: ', file_path)
dataset_list = json.load(open(file_path))
for item in dataset_list:
history_str = '\n '.join(item['history'])
# fill this with dialog history and slot-value pairs
data_item = {'history': history_str + "\n"}
# add extracted values for text/valid datasets
if 'values' in item:
data_item['values'] = item['values']
belief_states = item['belief_states']
if len(belief_states) == 0:
continue
slot_value_list = []
for belief_state in belief_states:
# split 'slot = value' using '=' delimiter
slot_value_split = belief_state.split("=")
# check if the 'slot = value' item is valid
if len(slot_value_split) != 2:
continue
slot = slot_value_split[0].strip().lower()
value = slot_value_split[1].strip().lower()
# skip slot values with invalid data like "none"|"None"
if value == "none" or value == "None":
continue # don't add this (slot, value) pair - Invalid
if slot and value:
slot_value_pair = (slot, value)
slot_value_list.append(slot_value_pair)
# If (slot, value) pairs are empty, continue & don't add this item
if len(slot_value_list) <= 0:
continue
data_item['belief_states'] = slot_value_list
self.total_num_slot_value_pairs += len(slot_value_list)
# add the data_item to the dataset_items list
# this item should be returned via getitem function
self.dataset_items.append(data_item)
# shuffle the data items list
if shuffle:
np.random.shuffle(self.dataset_items)
# print some statistics
print('Total data items = {}, Total (slot, value) pairs = {}'
.format(len(self.dataset_items), self.total_num_slot_value_pairs))
def getitem(self, index):
return self.dataset_items[index]
def len(self):
return len(self.dataset_items)
def total_slot_value_pairs(self):
return self.total_num_slot_value_pairs
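For illustration, a minimal usage sketch of the loader; the toy file name and contents below are made up, but follow the format the class expects (a JSON list of objects with history, belief_states as 'slot = value' strings, and values only for test/valid splits):

import json

sample = [{
    "history": ["user : i need a cheap hotel in the north"],
    "belief_states": ["price = cheap", "area = north"],
    "values": ["cheap", "north"]
}]
json.dump(sample, open("toy.soloist.json", "w"))

dataset = PromptDstDataset("toy.soloist.json", shuffle=False)
print(dataset.len())                        # 1
print(dataset.getitem(0)["belief_states"])  # [('price', 'cheap'), ('area', 'north')]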

@ -0,0 +1,49 @@
import json
import os
class PromptDSTEvaluator:
def __init__(self, outputs_file_path=None):
self.true_states_list = []
self.gen_states_list = []
if outputs_file_path is not None and os.path.isfile(outputs_file_path):
outputs = json.load(open(outputs_file_path))
for item in outputs:
self.true_states_list.append(item['true_states'])
self.gen_states_list.append(item['gen_states'])
def add_data_item(self, true_states, gen_states):
self.true_states_list.append(true_states)
self.gen_states_list.append(gen_states)
def compute_joint_goal_accuracy(self, no_print=False):
if not no_print:
print('Computing Joint Goal Accuracy metric...')
if len(self.true_states_list) != len(self.gen_states_list):
raise ValueError('true and generated state lists must have the same length!')
# keep a count for computing JGA
correctly_predicted, total_turns = 0, 0
for truth, generated in zip(self.true_states_list, self.gen_states_list):
total_turns += 1
if set(truth.keys()) != set(generated.keys()):
continue
has_wrong_slot_value = False
for slot in truth:
if truth[slot] != generated[slot]:
has_wrong_slot_value = True
break
if not has_wrong_slot_value:
correctly_predicted += 1
# guard against an empty evaluator; define JGA as 0 when there are no turns
jga_score = (correctly_predicted / total_turns) * 100 if total_turns else 0.0
if not no_print:
print('Joint Goal Accuracy = ', jga_score)
return jga_score
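A quick sanity check of the metric on hypothetical belief states (not taken from the dataset): the first turn matches exactly, the second disagrees on one value, so JGA comes out to 50.

evaluator = PromptDSTEvaluator()
evaluator.add_data_item({"area": "north"}, {"area": "north"})  # exact match
evaluator.add_data_item({"area": "north"}, {"area": "south"})  # wrong value
evaluator.compute_joint_goal_accuracy()  # prints Joint Goal Accuracy = 50.0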

@ -0,0 +1,135 @@
import argparse
import numpy as np
import os
import json
import torch
from transformers import AutoModelForCausalLM, GPT2Tokenizer
from dataset import PromptDstDataset
from tqdm.auto import tqdm
from prompt_utils import get_value_based_prompt
from metrics import PromptDSTEvaluator
from datetime import datetime
def set_seed(args):
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
# Generate the next token (the predicted slot name) from a value-based prompt
def generate_slot_from_prompt(history, value, tokenizer, model, device):
# build the value-based prompt for generation
prompt = get_value_based_prompt(value)
# combine history and prompt
prompt = history + prompt
# encode the history & prompt
encoded_prompt = tokenizer(prompt, return_tensors="pt")
encoded_prompt.to(device)
# generate 1 token
outputs = model.generate(**encoded_prompt, max_new_tokens=1)
gen_sequences = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
generated_word = tokenizer.decode(gen_sequences.item(), skip_special_tokens=True)
return generated_word.strip().lower()
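# Note (added for clarity): max_new_tokens=1 assumes every slot name the model
# must emit ("area", "price", "stay", ...) is a single BPE token under the
# GPT-2 tokenizer; a multi-token slot name would be truncated here.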
def main():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--test_data_file", default=None, type=str, required=True,
help="The test/eval data file <JSON Path>.")
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The directory where the predictions should be saved")
parser.add_argument("--tuned_model_path", default=None, type=str, required=True,
help="The fine-tuned model path")
# Optional
parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
# parse the arguments
args = parser.parse_args()
# setup CUDA device for training on GPU (if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.n_gpu = torch.cuda.device_count()
# prepare model & tokenizer -> load pre-trained model
tokenizer = GPT2Tokenizer.from_pretrained(args.tuned_model_path)
model = AutoModelForCausalLM.from_pretrained(args.tuned_model_path, pad_token_id=tokenizer.eos_token_id)
# set the device to the model
model.to(device)
# set seed
set_seed(args)
# load testing/eval dataset
dataset = PromptDstDataset(args.test_data_file)
# set tqdm progress bar over the test items
progress = tqdm(total=dataset.len(), desc="Progress")
# set eval mode
model.eval()
# outputs array -> to be saved to output_dir
outputs = []
# JGA metric
evaluator = PromptDSTEvaluator()
# iterate through test dataset and generate slots
for item in dataset.dataset_items:
history = item['history']
true_states = {}
gen_states = {}
# iterate through (slot, value) pairs and add them to true states
for slot, value in item['belief_states']:
true_states[slot] = value
# iterate through the extracted value candidates and generate a slot for each
for value in item['values']:
# generate slot using value-based prompt
generated_slot = generate_slot_from_prompt(history=history,
value=value,
tokenizer=tokenizer,
model=model,
device=device)
# add the generated slot to generated states
gen_states[generated_slot] = value
# update tqdm progress
progress.update(1)
# add true belief states & generated belief states to outputs
outputs.append({"history": history,
"true_states": true_states,
"gen_states": gen_states})
# add true & generated belief states to evaluator for computing JGA
evaluator.add_data_item(true_states.copy(), gen_states.copy())
# compute JGA & print results
evaluator.compute_joint_goal_accuracy()
# save the outputs to output_dir
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
now = datetime.now()
datetime_str = now.strftime("%Y%m%dT%H%M%S")
output_file = os.path.join(args.output_dir, 'outputs-{}.json'.format(datetime_str))
print('Saving Outputs file :: ', output_file)
json.dump(outputs, open(output_file, 'w'), indent=2)
if __name__ == "__main__":
main()
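For reference, a hypothetical direct invocation; prompt_decode.sh below resolves the real paths from SAVED_MODELS_PROMPT and OUTPUTS_DIR_PROMPT, and the experiment/epoch directory here is only a placeholder:

python prompt_decode.py \
    --test_data_file=../data/prompt-learning/test/test.soloist.json \
    --tuned_model_path=${SAVED_MODELS_PROMPT}/250-dpd/experiment-20221030T172424/epoch-08 \
    --output_dir=${OUTPUTS_DIR_PROMPT}/250-dpd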

@ -0,0 +1,295 @@
import argparse
import numpy as np
import os
import json
import torch
from transformers import AutoModelForCausalLM, GPT2Tokenizer
from dataset import PromptDstDataset
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
from prompt_utils import get_value_based_prompt
from prompt_utils import get_prompt_for_training
from prompt_utils import TYPE_VALUE_BASED_PROMPT
from prompt_utils import TYPE_INVERSE_PROMPT
from metrics import PromptDSTEvaluator
from datetime import datetime
def main():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--train_data_file", default=None, type=str, required=True,
help="The input training data file <JSON Path>.")
parser.add_argument("--save_model_dir", default=None, type=str, required=True,
help="The directory where the model should be saved")
parser.add_argument("--pretrained_model_path", default=None, type=str, required=True,
help="The pre-trained model path for fine tuning [Either original SOLOIST "
"or a saved model checkpoint]")
# Optional
parser.add_argument("--num_epochs", default=5, type=int,
help="Total number of training epochs to perform.")
parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
parser.add_argument("--batch_size", default=1, type=int, help="Batch size for training.")
parser.add_argument("--learning_rate", default=5e-5, type=float,
help="The initial learning rate for Adam Optimizer.")
parser.add_argument("--weight_decay", default=0.0, type=float,
help="Weight decay")
parser.add_argument("--with_inverse_prompt", action="store_true",
help="Flag for enabling/disabling inverse prompt during training")
parser.add_argument("--inverse_prompt_weight", default=0.1, type=float,
help="Weight to adjust the influence of Inverse Prompt, decimal (0,1)")
parser.add_argument("--validation_file", default="", type=str,
help="Validation file for evaluating model after each epoch")
# parse the arguments
args = parser.parse_args()
# setup CUDA device for training on GPU (if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.n_gpu = torch.cuda.device_count()
# prepare model & tokenizer -> load pre-trained model
tokenizer = GPT2Tokenizer.from_pretrained(args.pretrained_model_path, do_lower_case=True)
model = AutoModelForCausalLM.from_pretrained(args.pretrained_model_path, pad_token_id=tokenizer.eos_token_id)
# set the device to the model
model.to(device)
# set seed
set_seed(args)
# load training dataset
training_data = PromptDstDataset(args.train_data_file)
# load validation dataset
validation_data = None
if args.validation_file:
validation_data = PromptDstDataset(args.validation_file)
# create an optimizer and learning rate scheduler to fine-tune the model
# parameter names matched here are excluded from weight decay; note that GPT-2
# layer norms are named "ln_*", so this list would need extending if a nonzero
# weight_decay is used
no_decay = ["bias", "layer_norm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
lr_scheduler = get_scheduler(
name="linear",
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=(args.num_epochs * training_data.total_num_slot_value_pairs)
)
# set tqdm progress bars for Epochs & number of training steps
num_training_steps = args.num_epochs * training_data.len()
epochs = tqdm(total=args.num_epochs, desc="Epochs", position=0)
training_progress = tqdm(total=num_training_steps, desc="Training Progress", position=1)
# set the model in training mode
model.train()
# training loop
for epoch in range(args.num_epochs):
running_loss = 0.0
loss_count = 0
# set the model in training mode (after each epoch)
model.train()
for i, item in enumerate(training_data.dataset_items, start=1):
history = item['history']
# iterate through (slot, value) pairs
for slot, value in item['belief_states']:
# train/generate using value-based prompt first
loss, gen_slot = train_prompting(history=history,
slot_value_pair=(slot, value),
prompt_type=TYPE_VALUE_BASED_PROMPT,
tokenizer=tokenizer,
model=model,
device=device)
if args.with_inverse_prompt:
# use the generated slot from value-based prompt
# clean/process the generated slot (remove whitespaces & convert to lower case)
generated_slot = gen_slot.strip().lower()
# train slot generation using inverse prompt
inv_loss, _ = train_prompting(history=history,
slot_value_pair=(generated_slot, value),
prompt_type=TYPE_INVERSE_PROMPT,
tokenizer=tokenizer,
model=model,
device=device)
# compute total loss for this slot-value pair
loss = loss + (args.inverse_prompt_weight * inv_loss)
# store the loss for printing
running_loss += loss.item()
loss_count += 1
# backward pass & step
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# update progress
training_progress.update(1)
# print loss for every 100 steps
if i % 100 == 0:
last_loss = running_loss / loss_count
tqdm.write(str('Training Loss [Iteration {}, Epoch {}] = {}'.format(i, (epoch + 1), last_loss)))
running_loss = 0.0
loss_count = 0
# Save the model after finishing an epoch
output_dir = os.path.join(args.save_model_dir, 'epoch-{:02d}'.format(epoch + 1))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
tqdm.write(str('Saving model (after Epoch {} ) to :: {}'.format((epoch + 1), output_dir)))
# update epoch progress
epochs.update(1)
# Epoch finished -> run validation (after each epoch) if a validation file was provided
if args.validation_file and validation_data is not None:
# set tqdm progress bars for testing progress
validation_progress = tqdm(total=validation_data.len(), desc="Validation", leave=False)
# set eval mode
model.eval()
# outputs array -> to be saved to output_dir
outputs = []
# JGA metric
evaluator = PromptDSTEvaluator()
# iterate through validation dataset and generate slots using value-based prompt
for item in validation_data.dataset_items:
history = item['history']
true_states = {}
gen_states = {}
# iterate through (slot, value) pairs and generate each slot using value
for slot, value in item['belief_states']:
true_states[slot] = value
# generate slot using value-based prompt
generated_slot = generate_slot_from_prompt(history=history,
value=value,
tokenizer=tokenizer,
model=model,
device=device)
# add the generated slot to generated states
gen_states[generated_slot] = value
# update tqdm progress
validation_progress.update(1)
# add true belief states & generated belief states to outputs
outputs.append({"true_states": true_states, "gen_states": gen_states})
# add true & generated belief states to evaluator for computing JGA
evaluator.add_data_item(true_states.copy(), gen_states.copy())
validation_progress.close()
# compute JGA & print results
tqdm.write(str('Computing Joint Goal Accuracy metric with TRUE values...'))
jga_score = evaluator.compute_joint_goal_accuracy(no_print=True)
tqdm.write(str('Joint Goal Accuracy(with True Values) [after Epoch-{}]: {}'.format((epoch + 1), jga_score)))
# save the outputs to trained epoch dir
now = datetime.now()
datetime_str = now.strftime("%Y%m%dT%H%M%S")
output_file = os.path.join(output_dir, 'outputs-{}.json'.format(datetime_str))
tqdm.write(str('Saving Validation Outputs file [after Epoch-{}] :: {}'.format((epoch + 1), output_file)))
json.dump(outputs, open(output_file, 'w'), indent=2)
def set_seed(args):
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def train_prompting(history, slot_value_pair, prompt_type, tokenizer, model, device):
# get prompt for training based on "type"
prompt = get_prompt_for_training(prompt_type, slot_value_pair)
# combine history and prompt
input_prompt = history + prompt
# encode the history & prompt using tokenizer
encoded_prompt = tokenizer(input_prompt, return_tensors="pt")
encoded_prompt.to(device)
# get the last token id (a slot or a value, depending on the prompt type)
last_token = encoded_prompt['input_ids'][:, -1:]
# tensor.to() is not in-place, so keep the returned tensor
last_token = last_token.to(device)
# model outputs
outputs = model(**encoded_prompt)
# logits at position -2 are the model's prediction for the final (target) token
logits = outputs.logits[:, -2, :]
# softmax probabilities
probs = torch.nn.functional.softmax(logits, dim=-1)
# last token generation probability
last_token_prob = torch.gather(probs, 1, last_token).squeeze(-1)
loss = torch.negative(torch.log(last_token_prob))
# generated word -> the one with the highest probability
generated_word = None
if prompt_type == TYPE_VALUE_BASED_PROMPT:
# find the token with the highest probability, this will be the generated word
gen_word_token = torch.argmax(logits, dim=-1)
generated_word = tokenizer.decode(gen_word_token, skip_special_tokens=True).strip()
# loss is the negative log-probability of the target token
return loss, generated_word
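# Note (added for clarity): the hand-rolled loss above is the negative
# log-likelihood of the single target token. Assuming logits of shape
# (1, vocab_size), it is equivalent to
#     torch.nn.functional.cross_entropy(logits, last_token.squeeze(-1))
# since cross-entropy with a single target reduces to -log softmax(logits)[target].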
# Use this for generating next word (Validation after each epoch)
def generate_slot_from_prompt(history, value, tokenizer, model, device):
# build the value-based prompt for generation
prompt = get_value_based_prompt(value)
# combine history and prompt
prompt = history + prompt
# encode the history & prompt
encoded_prompt = tokenizer(prompt, return_tensors="pt")
encoded_prompt.to(device)
# generate 1 token
outputs = model.generate(**encoded_prompt, max_new_tokens=1)
gen_sequences = outputs[:, encoded_prompt['input_ids'].shape[-1]:]
generated_word = tokenizer.decode(gen_sequences.item(), skip_special_tokens=True)
return generated_word.strip().lower()
if __name__ == "__main__":
main()

@ -0,0 +1,30 @@
from string import Template
TYPE_VALUE_BASED_PROMPT = "value-based"
TYPE_INVERSE_PROMPT = "inverse-prompt"
PROMPT_TEMPLATES = {
"value-based": {
"training": "belief states: value = $value, slot = $slot",
"generate": "belief states: value = $value, slot ="
},
"inverse-prompt": {
"training": "belief states: slot = $slot, value = $value",
"generate": "belief states: slot = $slot, value ="
}
}
def get_prompt_for_training(typ, slot_value):
template = Template(PROMPT_TEMPLATES[typ]["training"])
return template.substitute(slot=slot_value[0], value=slot_value[1])
def get_value_based_prompt(value):
template = Template(PROMPT_TEMPLATES["value-based"]["generate"])
return template.substitute(value=value)
def get_inverse_prompt(slot):
template = Template(PROMPT_TEMPLATES["inverse-prompt"]["generate"])
return template.substitute(slot=slot)
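For concreteness, the templates render as follows for a toy (slot, value) pair:

get_prompt_for_training(TYPE_VALUE_BASED_PROMPT, ("area", "north"))
# -> 'belief states: value = north, slot = area'
get_value_based_prompt("north")
# -> 'belief states: value = north, slot ='
get_inverse_prompt("area")
# -> 'belief states: slot = area, value ='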

@ -0,0 +1,54 @@
#!/bin/bash
usage="$(basename "$0") [-m <fine-tuned-model-path>]
Argument -m takes the relative path of fine-tuned model from ${SAVED_MODELS_PROMPT}.
Example: -m 250-dpd/experiment-20221030T172424/epoch-08"
while getopts :m: flag
do
case "${flag}" in
m) model_path=${OPTARG};;
:) printf "missing argument for -%s\n" "$OPTARG" >&2; echo "$usage" >&2; exit 1;;
esac
done
# check for the mandatory -m argument
if [ ! "$model_path" ]; then
echo "argument -m must be provided"
echo "$usage" >&2; exit 1
fi
# Check whether the required environment vars are set
if [ -z "${SAVED_MODELS_PROMPT}" ]; then
echo "Must set SAVED_MODELS_PROMPT in environment, run set_env.sh first!";
exit 1
fi
# Check whether the required environment vars are set
if [ -z "${OUTPUTS_DIR_PROMPT}" ]; then
echo "Must set OUTPUTS_DIR_PROMPT in environment, run set_env.sh first!";
exit 1
fi
# check if the test data file exists
TEST_DATA_FILE=../data/prompt-learning/test/test.soloist.json
if [ ! -f "${TEST_DATA_FILE}" ]; then
echo "Test/Valid Data file does not exist!"
exit 1
fi
FINE_TUNED_MODEL_PATH=${SAVED_MODELS_PROMPT}/${model_path}
if [ ! -d "${FINE_TUNED_MODEL_PATH}" ]; then
echo "Invalid fine-tuned model path - ${model_path}"
exit 1
fi
OUTPUTS_DIR=${OUTPUTS_DIR_PROMPT}/${model_path}
# create the dirs if not exist
mkdir -p "${OUTPUTS_DIR}"
python prompt_decode.py \
--output_dir="${OUTPUTS_DIR}" \
--tuned_model_path="${FINE_TUNED_MODEL_PATH}" \
--test_data_file="${TEST_DATA_FILE}"

@ -0,0 +1,59 @@
#!/bin/bash
usage="$(basename "$0") [-d <data-split-name>]
Argument -d takes (few-shot) data split names.
Possible data-split names : 50-dpd|100-dpd|125-dpd|250-dpd"
while getopts :d: flag
do
case "${flag}" in
d) data_split=${OPTARG};;
:) printf "missing argument for -%s\n" "$OPTARG" >&2; echo "$usage" >&2; exit 1;;
esac
done
# check for the mandatory -d argument
if [ ! "$data_split" ]; then
echo "argument -d must be provided"
echo "$usage" >&2; exit 1
fi
# Check whether the required environment vars are set
if [ -z "${SAVED_MODELS_PROMPT}" ]; then
echo "Must set SAVED_MODELS_PROMPT in environment, run set_env.sh first!";
exit 1
fi
# Check whether the required environment vars are set
if [ -z "${PRE_TRAINED_SOLOIST}" ]; then
echo "Pre-trained SOLOIST Model path must be provided!";
echo "Must set PRE_TRAINED_SOLOIST in environment, run set_env.sh first!";
exit 1
fi
# check if the training data file exists
TRAIN_DATA_FILE=../data/prompt-learning/"${data_split}"/train.soloist.json
if [ -f "$TRAIN_DATA_FILE" ]; then
echo "Selected Training set :: ${data_split}/train.soloist.json"
else
echo "Training File with set ${data_split} does not exist."
exit 1
fi
# create experiment folder for storing saved models
datetime_now=$(date +"%Y%m%dT%H%M%S")
experiment_folder="${data_split}"/experiment-${datetime_now}
SAVE_DIR="${SAVED_MODELS_PROMPT}"/"${experiment_folder}"
echo "Trained Models (checkpoints/epochs) are saved in ${SAVE_DIR}"
python prompt_train.py \
--save_model_dir="${SAVE_DIR}" \
--pretrained_model_path="${PRE_TRAINED_SOLOIST}" \
--train_data_file="${TRAIN_DATA_FILE}" \
--validation_file=../data/prompt-learning/valid/valid.soloist.json \
--num_epochs 10 \
--learning_rate 5e-5 \
--with_inverse_prompt \
--inverse_prompt_weight 0.1

@ -7,7 +7,14 @@ export OUTPUTS_DIR_BASELINE=/mount/studenten/projects/mandavsi/baseline/outputs
# path of pretrained SOLOIST model
export PRE_TRAINED_SOLOIST=/mount/studenten/projects/mandavsi/soloist-pretrained/gtg_pretrained
# Path for storing prompt based models
export SAVED_MODELS_PROMPT=/mount/studenten/projects/mandavsi/prompt-learning/trained-models
# path for storing test outputs
export OUTPUTS_DIR_PROMPT=/mount/studenten/projects/mandavsi/prompt-learning/outputs
# create dirs if not exist
mkdir -p ${SAVED_MODELS_BASELINE}
mkdir -p ${OUTPUTS_DIR_BASELINE}
mkdir -p ${PRE_TRAINED_SOLOIST}
mkdir -p ${SAVED_MODELS_PROMPT}
mkdir -p ${OUTPUTS_DIR_PROMPT}

@ -0,0 +1,260 @@
import json
import re
import string
corenlp_props = {
'annotators': 'tokenize, pos, ner, dcoref',
'pipelineLanguage': 'en',
'outputFormat': 'json',
'parse.maxlen': '1000',
'timeout': '500000'
}
STOPWORDS_FILE = "../data/resource/stopwords.txt"
DOMAINS = ["hotel", "train", "restaurant", "attraction", "taxi", "book"]
SLOTS = {'area', 'arrive', 'day', 'departure', 'destination', 'food', 'internet', 'leave',
'name', 'parking', 'people', 'price', 'stars', 'stay', 'time', 'type'}
VALUES_CONVERT = {
'zero': '0',
'one': '1',
'two': '2',
'three': '3',
'four': '4',
'five': '5',
'six': '6',
'seven': '7',
'eight': '8',
'nine': '9',
'wifi': 'internet',
'wlan': 'internet',
'wi-fi': 'internet',
'moderately': 'moderate',
}
# vague deictic words and artifacts that should never become value candidates
BAD_ENTITY_WORDS = {"this", "that", "there", "here", "|", "less", "more"}
def bad_entity(text):
return text in BAD_ENTITY_WORDS
def fix_stanford_coref(stanford_json):
true_corefs = {}
# get a chain
for key, coref in stanford_json["corefs"].items():
true_coref = []
# get an entity mention
for entity in coref:
sent_num = entity["sentNum"] - 1 # starting from 0
start_index = entity["startIndex"] - 1 # starting from 0
end_index = entity["endIndex"] - 1 # starting from 0
head_index = entity["headIndex"] - 1 # starting from 0
entity_label = stanford_json["sentences"][
sent_num]["tokens"][head_index]["ner"]
entity["sentNum"] = sent_num
entity["startIndex"] = start_index
entity["endIndex"] = end_index
entity["headIndex"] = head_index
entity["headWord"] = entity["text"].split(
" ")[head_index - start_index]
entity["entityType"] = entity_label
true_coref.append(entity)
# check link is not empty
if len(true_coref) > 0:
no_representative = True
has_representative = False
for idx, entity in enumerate(true_coref):
if entity["isRepresentativeMention"]:
if not (entity["type"] == "PRONOMINAL" or
bad_entity(entity["text"].lower()) or
len(entity["text"].split(" ")) > 10):
no_representative = False
has_representative = True
# remove bad representative assignments
else:
true_coref[idx]["isRepresentativeMention"] = False
# check there exists one representative mention
if no_representative:
for idx, entity in enumerate(true_coref):
if not (entity["type"] == "PRONOMINAL" or
bad_entity(entity["text"].lower()) or
len(entity["text"].split(" ")) > 10):
true_coref[idx]["isRepresentativeMention"] = True
has_representative = True
if has_representative:
true_corefs[key] = true_coref
return true_corefs
def clean(corefs: list, stopwords_list: list):
"""Drop mention spans nested inside other spans and strip stopwords from span edges."""
dup_ids = []
for i, coref1 in enumerate(corefs):
consist_num = 0
short = []
for j, coref2 in enumerate(corefs):
if coref1[2][0] <= coref2[2][0] and coref1[2][1] >= coref2[2][1] and (not i == j):
consist_num += 1
short.append(j)
if consist_num > 1:
dup_ids.append(i)
elif consist_num == 1:
dup_ids.extend(short)
corefs = [corefs[i] for i in range(len(corefs)) if i not in dup_ids]
temp = []
for coref in corefs:
seq = coref[-1].split()
while seq and (seq[0] in stopwords_list or seq[-1] in stopwords_list):
if seq[0] in stopwords_list:
del seq[0]
# deleting the head may have emptied the list, so re-check before indexing
if seq and seq[-1] in stopwords_list:
del seq[-1]
if not seq:
temp.append(coref)
else:
coref[-1] = ' '.join(seq)
for t in temp:
corefs.remove(t)
return corefs
def get_candidates(user_annotation, stopwords_list):
"""Candidates include adjs, entities and corefs."""
tokens = []
candidates = {}
entities = []
postags = []
corefs = []
base_index = [0]
read_annotation(user_annotation, base_index, stopwords_list, tokens, entities, postags, corefs, 0)
candidates['postag'] = postags
candidates['coref'] = clean(corefs, stopwords_list)
candidates['coref'].extend(entities)
return candidates
def is_stop(text: str, stopwords_list: list):
# True if at least one non-stopword token remains in the text
remaining = [x for x in text.split() if x.lower() not in stopwords_list]
return bool(remaining)
def read_annotation(annotation, base_index, stopwords_list, tokens, entities, postags, corefs, num_sen):
sentences = annotation["sentences"]
for i, sentence in enumerate(sentences):
for entity in sentence['entitymentions']:
head_idx = base_index[i + num_sen] + entity['tokenBegin']
head = sentence['tokens'][entity['tokenBegin']]['originalText']
mention = entity['text']
mention_start_idx = base_index[i + num_sen] + entity['tokenBegin']
mention_end_idx = base_index[i + num_sen] + entity['tokenEnd']
mention_idx = [mention_start_idx, mention_end_idx]
entities.append([head_idx, head, mention_idx, mention])
for j, token in enumerate(sentence['tokens']):
tokens.append(token['word'])
pos = token['pos']
lemma = token['lemma']
text = token['originalText']
if pos in ['JJ', 'RB']:
# guard explicitly: for j == 0 the index j - 1 would wrap around to the
# end of the sentence instead of raising IndexError
prev = sentence['tokens'][j - 1]['originalText'] if j > 0 else ''
if (not re.search(r"([a-z]\.[a-z])", lemma)) \
and lemma not in stopwords_list and prev != 'not':
head_idx = base_index[i + num_sen] + token['index'] - 1
postags.append([head_idx, text])
base_index.append(base_index[-1] + len(sentence['tokens']))
for coref in annotation['corefs'].values():
for realization in coref:
sent_num = realization['sentNum']
head_index = realization['headIndex']
head_idx = base_index[sent_num + num_sen] + head_index
head = sentences[sent_num]['tokens'][head_index]['originalText']
text_start_index = realization['startIndex']
text_start_idx = base_index[sent_num + num_sen] + text_start_index
text_end_index = realization['endIndex']
text_end_idx = base_index[sent_num + num_sen] + text_end_index
text_lemma = sentences[sent_num]['tokens'][text_start_index:text_end_index]
text_lemma = ' '.join(list(map(lambda x: x['originalText'], text_lemma)))
# guard explicitly: negative indices wrap around rather than raising,
# and a bare "except BaseException" would hide real errors
prev1 = sentences[sent_num]['tokens'][text_start_index - 1]['originalText'] if text_start_index >= 1 else ''
prev2 = sentences[sent_num]['tokens'][text_start_index - 2]['originalText'] if text_start_index >= 2 else ''
should_stop = is_stop(text_lemma, stopwords_list)
if should_stop and prev1 != 'not' and prev2 != 'not':
corefs.append([head_idx, head, [text_start_idx, text_end_idx], text_lemma])
def get_value_candidates_from_history(corenlp, history):
if len(history) == 0:
return []
stopwords = []
with open(STOPWORDS_FILE, 'r') as fin:
for line in fin:
stopwords.append(line.strip())
value_candidates = set()
user_utterance = ' '.join(utterance[len('user :'):] for utterance in history if utterance.startswith('user :'))
annotation = json.loads(corenlp.annotate(user_utterance, properties=corenlp_props))
annotation['corefs'] = fix_stanford_coref(annotation)
candidates = get_candidates(annotation, stopwords)
for _, candidate in candidates.items():
for c in candidate:
if len(c) == 2:
value_candidates.add(c[1].strip().lower())
else:
if len(c[3].split()) > 5:
value_candidates.add(c[1].strip().lower())
else:
value_candidates.add(c[3].strip().lower())
# clean value candidates
values = set()
for value in value_candidates:
if value in VALUES_CONVERT:
value = VALUES_CONVERT[value]
if value not in DOMAINS \
and value not in SLOTS \
and value not in string.punctuation \
and value not in stopwords \
and not value.startswith("'"):
# remove spaces before punctuation
value = re.sub(r"\s+([?.!'])", r"\1", value).strip()
if value and value[0].isdigit():
# remove everything after end of a number
value = re.sub(r'\D+$', '', value)
if value.strip() and len(value.split()) <= 4:
values.add(value.strip())
return list(values)
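A minimal driver for the extractor, assuming the CoreNLP distribution is unpacked at ../utils/stanford-corenlp (the same path the data script below uses) and the stopwords file above is in place at ../data/resource/stopwords.txt:

from stanfordcorenlp import StanfordCoreNLP

nlp = StanfordCoreNLP("../utils/stanford-corenlp", memory="8g")
history = ["user : i want a cheap hotel in the north",
           "system : okay , any other preferences ?"]
print(get_value_candidates_from_history(nlp, history))
# e.g. ['cheap', 'north'] -- extracted candidates, not gold values
nlp.close()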

@ -1,12 +1,13 @@
import collections
import json
import os
from pathlib import Path
from stanfordcorenlp import StanfordCoreNLP
from corenlp import get_value_candidates_from_history
from collections import Counter
BELIEF_PREFIX = 'belief :'
CORENLP_PATH = "../utils/stanford-corenlp"
ALL_SLOTS = ['area', 'arriveby', 'day', 'departure', 'destination',
'food', 'internet', 'leaveat', 'name', 'parking',
'people', 'pricerange', 'stars', 'time', 'type']
MODIFIED_SLOTS = {
'area': 'area',
@ -27,6 +28,8 @@ MODIFIED_SLOTS = {
'type': 'type'
}
max_len = 0
def convert_slot_for_prompting(slot_value_item):
# check if the 'slot = value' item is valid
@ -38,6 +41,10 @@ def convert_slot_for_prompting(slot_value_item):
slot = slot_value_item.split('=')[0].strip()
value = slot_value_item.split('=')[1].strip()
# skip invalid slot values
if value.lower() == 'none':
return ''
# modify the slot for prompting
modified_slot = MODIFIED_SLOTS[slot]
@ -57,6 +64,12 @@ def create_belief_states_data_for_prompt_learning(data_tuple):
if len(data) <= 0:
return
nlp = None
# start the CoreNLP server for Value Extraction
# Only required for test/valid dataset
if data_tuple[1] in ['test', 'valid']:
nlp = StanfordCoreNLP(CORENLP_PATH, memory='8g')
# data to be saved for prompt learning
belief_states_dataset = []
@ -67,6 +80,10 @@ def create_belief_states_data_for_prompt_learning(data_tuple):
belief_states_data_item['history'] = item['history']
belief_states_data_item['domains'] = item['domains']
# extract value candidates using stanford CoreNLP & add to test/valid dataset
if data_tuple[1] in ['test', 'valid']:
belief_states_data_item['values'] = get_value_candidates_from_history(nlp, item['history'])
# extract belief states
belief_states = item['belief']
@ -113,6 +130,8 @@ def create_belief_states_data_for_prompt_learning(data_tuple):
# add to the dataset (to be saved!)
belief_states_dataset.append(belief_states_data_item)
if data_tuple[1] in ['test', 'valid'] and nlp is not None:
nlp.close()
# save the dataset file
save_file_path = '../data/prompt-learning/' + data_tuple[2] + '/'
save_file_name = data_tuple[1] + '.soloist.json'