You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

65 lines
2.0 KiB

from string import Template
# Identifiers for the supported prompt types; each one is a key of
# PROMPT_TEMPLATES below.
TYPE_VALUE_BASED_PROMPT = "value-based"
TYPE_INVERSE_PROMPT = "inverse-prompt"
TYPE_PROMPT_ENSEMBLE = "prompt-ensemble"

# Not referenced in this section; presumably a weighting factor applied by the
# caller (e.g. to an auxiliary loss term) -- TODO confirm against its users.
PROMPT_WEIGHT = 0.25

# Inverse-prompt variants. `$slot` and `$value` are string.Template
# placeholders.
INVERSE_PROMPTS = {
    "i1": "belief states: $slot = $value",
    "i2": "belief states: slot = $slot, value = $value",
}

# Prompt templates, keyed by prompt type.
#   * "training" templates contain both `$slot` and `$value`.
#   * "generate" templates contain only `$value` and end where the slot name
#     appears in the matching training template.
# An entry is either a single template string or a dict of named templates
# (the ensemble). The helper functions iterate such dicts in insertion order,
# so the order of "p1".."p4" is significant.
PROMPT_TEMPLATES = {
    TYPE_VALUE_BASED_PROMPT: {
        "training": "belief states: value = $value, slot = $slot",
        "generate": "belief states: value = $value, slot =",
    },
    TYPE_INVERSE_PROMPT: {
        # Only a training form is defined; it reuses inverse prompt "i2".
        "training": INVERSE_PROMPTS["i2"],
    },
    TYPE_PROMPT_ENSEMBLE: {
        "training": {
            "p1": "belief states: value = $value, slot = $slot",
            "p2": "belief states: $value = $slot",
            "p3": "$value is of slot type $slot",
            "p4": "$value is the value of $slot",
        },
        "generate": {
            "p1": "belief states: value = $value, slot =",
            "p2": "belief states: $value =",
            "p3": "$value is of slot type",
            "p4": "$value is the value of",
        },
    },
}
def get_prompt_for_training(typ, slot_value, templates=None):
    """Build the training prompt(s) for one slot/value pair.

    Args:
        typ: Prompt-type key into the template table, e.g. ``"value-based"``,
            ``"inverse-prompt"`` or ``"prompt-ensemble"``.
        slot_value: Indexable pair. Element 0 is substituted for ``$slot``
            and element 1 for ``$value``.
        templates: Optional template table with the same layout as
            ``PROMPT_TEMPLATES``. Defaults to ``PROMPT_TEMPLATES``.

    Returns:
        A single prompt string when the ``"training"`` entry for ``typ`` is a
        string. A list of prompt strings, in the dict's insertion order, when
        the entry is a dict of named templates.

    Raises:
        KeyError: if ``typ`` is unknown or has no ``"training"`` entry, or if a
            template uses a placeholder other than ``$slot`` / ``$value``.
        TypeError: if the ``"training"`` entry is neither a str nor a dict.
    """
    if templates is None:
        templates = PROMPT_TEMPLATES
    template = templates[typ]["training"]
    # Both placeholders are always supplied; a template may use either or both.
    mapping = {"slot": slot_value[0], "value": slot_value[1]}
    if isinstance(template, str):
        return Template(template).substitute(mapping)
    if isinstance(template, dict):
        return [Template(t).substitute(mapping) for t in template.values()]
    # Previously this fell through and returned None, which hid a
    # misconfigured template table from the caller.
    raise TypeError(
        f"unsupported training template type for {typ!r}: "
        f"{type(template).__name__}"
    )
def get_value_based_prompt(value):
    """Return the value-based generation prompt with ``$value`` filled in.

    The template is the ``"generate"`` entry of the value-based prompt type.
    In the table it ends at ``slot =``, presumably so that a model completes
    it with the slot name (assumption; confirm with the caller).

    Raises:
        KeyError: if the template uses a placeholder other than ``$value``.
    """
    generate_text = PROMPT_TEMPLATES[TYPE_VALUE_BASED_PROMPT]["generate"]
    prompt = Template(generate_text).substitute(value=value)
    return prompt
def get_ensemble_prompts(value):
    """Build one generation prompt per prompt-ensemble template.

    Args:
        value: Text substituted for the ``$value`` placeholder.

    Returns:
        A list of prompt strings, one for each ``"generate"`` template of the
        ``TYPE_PROMPT_ENSEMBLE`` entry in ``PROMPT_TEMPLATES``. The list
        follows that dict's insertion order.

    Raises:
        KeyError: if a template uses a placeholder other than ``$value``.
    """
    generate_templates = PROMPT_TEMPLATES[TYPE_PROMPT_ENSEMBLE]["generate"]
    return [
        Template(template_str).substitute(value=value)
        for template_str in generate_templates.values()
    ]