Skip to content

Training models

This section describes how to train models in the AI & Analytics Engine using the SDK.

When models are trained in the same app, all of them use the same train-test split generated when the app was created. This ensures the models can be compared fairly and scientifically.

Note: The flow of training a model in the UI is the reverse of the flow when using the API from code. With the API, you first use the model recommendation and then create the models for the app. In the UI, you start by creating a model and then arrive at the model recommendation.

Train a model using GridSearch with default hyperparameters

from aiaengine import *

# create a new project
org_id = 'b6240512-cd17-43a0-8297-84c51c1bc5a0' # replace with your org ID
org = Org(org_id)
project = org.create_project(name="Demo project using Python SDK", description="Your demo project")

# create a new dataset from the local German Credit CSV file,
# declaring the type of every column explicitly
data_file = 'examples/datasets/german-credit.csv'
dataset = project.create_dataset(
    name="German Credit Data",
    data_source=FileSource(
        file_urls=[data_file],
        schema=[
            Column('checking_status', DataType.Text),
            Column('duration', DataType.Numeric),
            Column('credit_history', DataType.Text),
            Column('purpose', DataType.Text),
            Column('credit_amount', DataType.Numeric),
            Column('savings_status', DataType.Text),
            Column('employment', DataType.Text),
            Column('installment_commitment', DataType.Numeric),
            Column('personal_status', DataType.Text),
            Column('other_parties', DataType.Text),
            Column('residence_since', DataType.Numeric),
            Column('property_magnitude', DataType.Text),
            Column('age', DataType.Numeric),
            Column('other_payment_plans', DataType.Text),
            Column('housing', DataType.Text),
            Column('existing_credits', DataType.Numeric),
            Column('job', DataType.Text),
            Column('num_dependents', DataType.Numeric),
            Column('own_telephone', DataType.Text),
            Column('foreign_worker', DataType.Text),
            Column('class', DataType.Text)
        ]
    )
)

# create a new binary classification app that predicts the `class` column
# ("good" is the positive label, "bad" the negative one)
app = project.create_app(
    name="German Credit Risk Prediction Task",
    dataset_id=dataset.id,
    config=ClassificationConfig(
        sub_type=ClassificationSubType.BINARY,
        target_column="class",
        positive_class_label="good",
        negative_class_label="bad"
    )
)

# use the recommended feature set
feature_set = app.get_recommended_feature_set()

## selecting recommended models
# Once the app is processed successfully, model recommendations are provided
# with predicted performance over a range of metrics such as accuracy and
# F1-macro score (for classification), as well as estimated time cost in
# training and prediction. In this example, we print the top 5 models based on
# F1-macro score. The printout is for inspection only: the model trained below
# is always XGBoost, which is not necessarily one of these 5.
print(feature_set.select_recommended_models(n=5, by_metric='f1_macro'))

# train a model using the feature set and the XGBoost algorithm; no tuning
# method or hyperparameters are passed, so the template's defaults are used
model = app.create_model(
    name="XGBoost Classifier",
    template_id=Classifiers.XGBoost,
    feature_set_id=feature_set.id
)
package com.aiaengine.examples.model;

import com.aiaengine.*;
import com.aiaengine.app.ClassificationConfig;
import com.aiaengine.app.request.CreateModelRequest;
import com.aiaengine.datasource.DataSource;
import com.aiaengine.datasource.Schema;
import com.aiaengine.datasource.file.CSVFileSettings;
import com.aiaengine.datasource.file.FileSourceRequest;
import com.aiaengine.datasource.file.FileType;
import com.aiaengine.project.request.CreateAppRequest;
import com.aiaengine.project.request.CreateDatasetRequest;

import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.List;

public class TrainModelApp {
    public static void main(String[] args) throws FileNotFoundException {
        Engine engine = new Engine();
        // create a new demo project in the org
        Org org = engine.getOrg("cae24b10-e6b0-4d61-8cef-a9f4b8f6133d"); // replace with your org ID
//        Project project = org.createProject(CreateProjectRequest.builder()
//                .name("Demo project using Java SDK")
//                .description("Your demo project")
//                .build());
        // or you can get an existing project that you want to work on
        // Project project = engine.getProject("ID_of_your_project") // replace with your own project ID
        Project project = engine.getProject("c6e4589e-9cf4-4191-a85b-6eaf5fa80bf5");

        // import the `German Credit Data` dataset
        String dataFilePath = "examples/datasets/german-credit.csv";
        List<Schema.Column> columns = new ArrayList<>();
        columns.add(new Schema.Column("checking_status", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("duration", Schema.SemanticType.NUMERIC));
        columns.add(new Schema.Column("credit_history", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("purpose", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("credit_amount", Schema.SemanticType.NUMERIC));
        columns.add(new Schema.Column("savings_status", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("employment", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("installment_commitment", Schema.SemanticType.NUMERIC));
        columns.add(new Schema.Column("personal_status", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("other_parties", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("residence_since", Schema.SemanticType.NUMERIC));
        columns.add(new Schema.Column("property_magnitude", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("age", Schema.SemanticType.NUMERIC));
        columns.add(new Schema.Column("other_payment_plans", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("housing", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("existing_credits", Schema.SemanticType.NUMERIC));
        columns.add(new Schema.Column("job", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("num_dependents", Schema.SemanticType.NUMERIC));
        columns.add(new Schema.Column("own_telephone", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("foreign_worker", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("class", Schema.SemanticType.TEXT));
        DataSource localDataSource = engine.buildFileSource(FileSourceRequest.builder()
                .fileType(FileType.CSV)
                .url(dataFilePath)
                .fileSettings(new CSVFileSettings())
                .schema(new Schema(columns))
                .build());

        Dataset dataset = project.createDataset(CreateDatasetRequest.builder()
                .name("German Credit Data")
                .dataSource(localDataSource)
                .timeout(900)
                .build());

        App app = project.createApp(CreateAppRequest.builder()
                .name("Predict Customer Credit - Binary Classification")
                .datasetId(dataset.getId())
                .config(new ClassificationConfig("class",
                        ClassificationConfig.ClassificationSubType.BINARY,
                        "good", "bad"))
                .build());

        //use recommended featureset
        FeatureSet featureSet = app.getRecommendedFeatureSet(600);
        //train model with default hyperparameters
        app.createModel(CreateModelRequest.builder()
                        .name("XGBoost Classifier")
                        .featureSetId(featureSet.getId())
                        .templateId("xgboosting_clf")
                .build());
    }
}

Train a model using GridSearch with provided hyperparameters

from aiaengine import *

# create a new project
org_id = 'b6240512-cd17-43a0-8297-84c51c1bc5a0' # replace with your org ID
org = Org(org_id)
project = org.create_project(name="Demo project using Python SDK", description="Your demo project")

# create a new dataset from the local German Credit CSV file,
# declaring the type of every column explicitly
data_file = 'examples/datasets/german-credit.csv'
dataset = project.create_dataset(
    name="German Credit Data",
    data_source=FileSource(
        file_urls=[data_file],
        schema=[
            Column('checking_status', DataType.Text),
            Column('duration', DataType.Numeric),
            Column('credit_history', DataType.Text),
            Column('purpose', DataType.Text),
            Column('credit_amount', DataType.Numeric),
            Column('savings_status', DataType.Text),
            Column('employment', DataType.Text),
            Column('installment_commitment', DataType.Numeric),
            Column('personal_status', DataType.Text),
            Column('other_parties', DataType.Text),
            Column('residence_since', DataType.Numeric),
            Column('property_magnitude', DataType.Text),
            Column('age', DataType.Numeric),
            Column('other_payment_plans', DataType.Text),
            Column('housing', DataType.Text),
            Column('existing_credits', DataType.Numeric),
            Column('job', DataType.Text),
            Column('num_dependents', DataType.Numeric),
            Column('own_telephone', DataType.Text),
            Column('foreign_worker', DataType.Text),
            Column('class', DataType.Text)
        ]
    )
)

# create a new binary classification app that predicts the `class` column
# ("good" is the positive label, "bad" the negative one)
app = project.create_app(
    name="German Credit Risk Prediction Task",
    dataset_id=dataset.id,
    config=ClassificationConfig(
        sub_type=ClassificationSubType.BINARY,
        target_column="class",
        positive_class_label="good",
        negative_class_label="bad"
    )
)

# use the recommended feature set
feature_set = app.get_recommended_feature_set()

# train a model using your own hyperparameter grid: each key is an XGBoost
# hyperparameter and each list holds the candidate values grid search tries
model = app.create_model(
    name="XGBoost Classifier",
    template_id=Classifiers.XGBoost,
    feature_set_id=feature_set.id,
    hyperparameters_tuning_method=HyperparameterTuningMethod.GRID_SEARCH,
    hyperparameter_tuning_config={},
    hyperparameters={
        "learning_rate": [
            0.001,
            0.01,
            0.1,
            1
        ],
        "max_depth": [
            2,
            4,
            8
        ],
        "reg_lambda": [
            0.001,
            0.1,
            1,
            10
        ]
    })
package com.aiaengine.examples.model;

import com.aiaengine.*;
import com.aiaengine.app.ClassificationConfig;
import com.aiaengine.app.request.CreateModelRequest;
import com.aiaengine.datasource.DataSource;
import com.aiaengine.datasource.Schema;
import com.aiaengine.datasource.file.CSVFileSettings;
import com.aiaengine.datasource.file.FileSourceRequest;
import com.aiaengine.datasource.file.FileType;
import com.aiaengine.model.HyperparameterTuningMethod;
import com.aiaengine.project.request.CreateAppRequest;
import com.aiaengine.project.request.CreateDatasetRequest;

import java.io.FileNotFoundException;
import java.util.*;

public class TrainModelProvideHyperparametersApp {
    public static void main(String[] args) throws FileNotFoundException {
        Engine engine = new Engine();
        // create a new demo project in the org
        Org org = engine.getOrg("cae24b10-e6b0-4d61-8cef-a9f4b8f6133d"); // replace with your org ID
//        Project project = org.createProject(CreateProjectRequest.builder()
//                .name("Demo project using Java SDK")
//                .description("Your demo project")
//                .build());
        // or you can get an existing project that you want to work on
        // Project project = engine.getProject("ID_of_your_project") // replace with your own project ID
        Project project = engine.getProject("c6e4589e-9cf4-4191-a85b-6eaf5fa80bf5");

        // import the `German Credit Data` dataset
        String dataFilePath = "examples/datasets/german-credit.csv";
        List<Schema.Column> columns = new ArrayList<>();
        columns.add(new Schema.Column("checking_status", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("duration", Schema.SemanticType.NUMERIC));
        columns.add(new Schema.Column("credit_history", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("purpose", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("credit_amount", Schema.SemanticType.NUMERIC));
        columns.add(new Schema.Column("savings_status", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("employment", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("installment_commitment", Schema.SemanticType.NUMERIC));
        columns.add(new Schema.Column("personal_status", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("other_parties", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("residence_since", Schema.SemanticType.NUMERIC));
        columns.add(new Schema.Column("property_magnitude", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("age", Schema.SemanticType.NUMERIC));
        columns.add(new Schema.Column("other_payment_plans", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("housing", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("existing_credits", Schema.SemanticType.NUMERIC));
        columns.add(new Schema.Column("job", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("num_dependents", Schema.SemanticType.NUMERIC));
        columns.add(new Schema.Column("own_telephone", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("foreign_worker", Schema.SemanticType.TEXT));
        columns.add(new Schema.Column("class", Schema.SemanticType.TEXT));
        DataSource localDataSource = engine.buildFileSource(FileSourceRequest.builder()
                .fileType(FileType.CSV)
                .url(dataFilePath)
                .fileSettings(new CSVFileSettings())
                .schema(new Schema(columns))
                .build());

        Dataset dataset = project.createDataset(CreateDatasetRequest.builder()
                .name("German Credit Data")
                .dataSource(localDataSource)
                .timeout(900)
                .build());

        App app = project.createApp(CreateAppRequest.builder()
                .name("Predict Customer Credit - Binary Classification")
                .datasetId(dataset.getId())
                .config(new ClassificationConfig("class",
                        ClassificationConfig.ClassificationSubType.BINARY,
                        "good", "bad"))
                .build());

        //use recommended featureset
        FeatureSet featureSet = app.getRecommendedFeatureSet(600);
        //train model with provided hyperparameters
        Map<String, Object> parameters = new HashMap<>();
        parameters.put("learning_rate", Arrays.asList(0.001, 0.01, 0.1, 1));
        parameters.put("max_depth", Arrays.asList(2, 4, 6));
        parameters.put("reg_lambda", Arrays.asList(0.001, 0.1, 1, 10));
        app.createModel(CreateModelRequest.builder()
                        .name("XGBoost Classifier")
                        .featureSetId(featureSet.getId())
                        .templateId("xgboosting_clf")
                        .hyperparametersTuningMethod(HyperparameterTuningMethod.GRID)
                        .hyperparameters(parameters)
                .build());
    }
}