You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

17 lines
395 B

from .kronos import KronosTokenizer, Kronos, KronosPredictor
# Registry mapping each model-name string to the class imported from `.kronos`;
# `get_model_class` looks names up here.
model_dict = {
    "kronos_tokenizer": KronosTokenizer,
    "kronos": Kronos,
    "kronos_predictor": KronosPredictor,
}
def get_model_class(model_name):
    """Return the class registered in ``model_dict`` under ``model_name``.

    Args:
        model_name: Key to look up in ``model_dict``.

    Returns:
        The value stored in ``model_dict`` for ``model_name``.

    Raises:
        NotImplementedError: If ``model_name`` is not a key of ``model_dict``.
            The message names the missing key and lists the registered keys.
    """
    try:
        return model_dict[model_name]
    except KeyError:
        # Put the diagnostic in the exception itself rather than printing it,
        # so it shows up in tracebacks and logs. The exception type is kept
        # as NotImplementedError so existing callers still catch it;
        # ``from None`` hides the internal KeyError, which adds nothing.
        raise NotImplementedError(
            f"Model {model_name} not found in model_dict "
            f"(available: {list(model_dict)})"
        ) from None