About

29 янв. 2023 г.

Запуск Apache Ignite ML: мой опыт


Вам захотелось поковыряться в machine learning, но вы не питонист. Что делать?! Нет, не изучать питон, когда есть много интересного и в любимом языке. 

А т.к. меня особо привлекают малоизвестные проекты, наткнувшить на статью на хабре про ml библиотеку Apache Ignite ML, которая не на слуху, не мог пройти мимо. Вообще библиотека Apache Ignite ML одна из многих частей Apache Ignite - кластера для распределенных вычислений.

Итак, на старте. Интерес к ml и желание написать свой пример классификации и поэкспериментировать с точнотью распознавания. В качестве датасета тот самый банальный MNIST dataset, который несложно найти уже конвертированным в csv. В качестве основы пример из оф.репозитрия Apache Ignite классификации OneVsRestClassification алгоритмом SVM. Запускаю на одном своем компе, у которого 16г оперативки. Получившийся код на гитхабе.

Начало.

Создаю новый проект, добавляю либы игнайта: 

<properties>
...
<ignite.version>2.14.0</ignite.version>
</properties>

<dependencies>
<dependency>
<groupId>org.apache.ignite</groupId>
<artifactId>ignite-core</artifactId>
<version>${ignite.version}</version>
</dependency>
<dependency>
<groupId>org.apache.ignite</groupId>
<artifactId>ignite-ml</artifactId>
<version>${ignite.version}</version>
</dependency>
</dependencies>

Конфигурацию запуска ignite буду писать в java файле, а не xml, поэтому больше ничего подключать на нужно. 
Добавил датасет, и классы из примера отвечающие за загрузку его в игнайт. Да, их надо по-хорошему грузить где-то заранее в игнайт, но это тема уже другого исследования.

За загрузку и извлечение данных из кеша будет отвечать реализация интерфейса CacheManager. 

public interface CacheManager {

void fillTrain();

void fillTest();

IgniteCache<Integer, Vector> getTrainCache();

IgniteCache<Integer, Vector> getTestCache();

Ignite getIgnite();

void destroyAll();
}

Буду тренировать на полном датасете в 60000 объетов. Мой комп вполне выдерживает, ничего убавлять не буду. Результаты серии тренирововк буду запусывать в файл.

Алгоритм.

Тренировка и предсказание в классе наследнике PredictService,

public abstract class PredictService<T extends PredictService.Params> {

public abstract void fitAndPredict(CacheManager cacheService, T params);
    ...
}

Предполагается добавление других аглоритмов, наследующих этот класс и реализующих свою тренировку со своими параметрами.  

Пара слов о svm. Это алгоритм бинарной классификации, и чтобы применить его в мультиклассовой, используют подход One-vs-Rest и One-vs-one. В случае One-vs-Rest аглоритм разбивает множество на серию таких наборов [искомый_класс_1] и [все_остальное], [искомый_класс_2] и [все_остальное] и т.д. Подробней тут.

Сама магия сосредоточена в SVMPredictService:

public class SVMPredictService extends PredictService<SVMParams> {

@Override
public void fitAndPredict(CacheManager cacheManager, SVMParams params) {
Ignite ignite = cacheManager.getIgnite();
IgniteCache<Integer, Vector> cache = cacheManager.getTrainCache();

System.out.println(">>> Start fit model...");
long startMs = System.currentTimeMillis();

var preprocessor = createPreprocessor(ignite, cache);
var oneVsRestTrainer = new OneVsRestTrainer<>(createSVMTrainer(params));
MultiClassModel<SVMLinearClassificationModel> model = oneVsRestTrainer.fit(ignite,
cache,
preprocessor);

double timeFitSec = (System.currentTimeMillis() - startMs) / 1000.0;
System.out.printf(">>> Complete fit model, time = %.4f sec, params = %s \n", timeFitSec, params);

double accuracy = predict(model, cacheManager);

statisticWrite(params, accuracy, cache.size(), timeFitSec);
}
    ...
}

Тут происходит тренировка модели. Далее вызывается метод predict где берем тестовый датасет и тренированную модель, предсказываем и считаем точность.

Поиск оптимальных параметров.

Запуск 1.

Запускаю с различными значениям параметров. В целом точность колеблется в районе 0.7.
Первый эксперимент с диапазоном величин:
30 <= iterations <= 60,
60 <= locIterations <= 180,
0.2 <= lambda <= 0.8

Оставляю в таблице лучшие результаты с точностью выше 0.77. 

Date,Iterations,LocIterations,Lambda,TrainSize,TimeFit(sec),Accuracy
2023-01-27 22:09:12, 30, 120, 0.40, 60000, 58.67, 0.7763
2023-01-27 22:15:18, 30, 180, 0.60, 60000, 54.14, 0.7792
2023-01-27 22:18:34, 30, 150, 0.80, 60000, 52.26, 0.7710

2023-01-27 22:25:45, 45, 90, 0.40, 60000, 61.84, 0.7724
2023-01-27 22:28:34, 45, 180, 0.40, 60000, 58.44, 0.7797
2023-01-27 22:31:15, 45, 120, 0.60, 60000, 58.30, 0.7764
2023-01-27 22:32:09, 45, 150, 0.60, 60000, 53.75, 0.7792
2023-01-27 22:34:38, 45, 90, 0.80, 60000, 47.34, 0.7737
2023-01-27 22:35:28, 45, 120, 0.80, 60000, 50.32, 0.7811
2023-01-27 22:36:23, 45, 150, 0.80, 60000, 54.30, 0.7748
2023-01-27 22:37:21, 45, 180, 0.80, 60000, 58.13, 0.7961

2023-01-27 22:38:10, 60, 60, 0.20, 60000, 48.03, 0.7723
2023-01-27 22:43:17, 60, 60, 0.40, 60000, 49.05, 0.7855
2023-01-27 22:44:15, 60, 90, 0.40, 60000, 57.46, 0.7748
2023-01-27 22:49:18, 60, 90, 0.60, 60000, 56.32, 0.7921
2023-01-27 22:50:20, 60, 120, 0.60, 60000, 61.15, 0.7842
2023-01-27 22:51:26, 60, 150, 0.60, 60000, 65.72, 0.7967
2023-01-27 22:52:34, 60, 180, 0.60, 60000, 68.28, 0.7701
2023-01-27 22:55:23, 60, 120, 0.80, 60000, 61.83, 0.7900
2023-01-27 22:56:24, 60, 150, 0.80, 60000, 61.00, 0.7959
2023-01-27 22:57:27, 60, 180, 0.80, 60000, 62.70, 0.7877

Наблидаем при увеличении iterations и других параметров, число хороших результатов увеличивается. Очевидно параметр iterations можно увеличить. Так же LocIterations и Lambda. 

Запуск 2.

Изменю диапазоны параметров, теперь прогоняю для: 
60 <= iterations <= 90,
120 <= locIterations <= 180,
0.5 <= lambda <= 0.9
Date,Iterations,LocIterations,Lambda,TrainSize,TimeFit(sec),Accuracy
2023-01-28 19:10:36, 60, 120, 0.50, 60000, 81.90, 0.7844
2023-01-28 19:11:42, 60, 150, 0.50, 60000, 65.44, 0.7631
2023-01-28 19:13:01, 60, 180, 0.50, 60000, 78.13, 0.7383
2023-01-28 19:13:59, 60, 120, 0.70, 60000, 58.41, 0.7870
2023-01-28 19:15:01, 60, 150, 0.70, 60000, 61.52, 0.7777
2023-01-28 19:16:10, 60, 180, 0.70, 60000, 68.58, 0.7698
2023-01-28 19:17:12, 60, 120, 0.90, 60000, 62.00, 0.7914
2023-01-28 19:18:22, 60, 150, 0.90, 60000, 69.56, 0.7976
2023-01-28 19:19:45, 60, 180, 0.90, 60000, 82.86, 0.7939
2023-01-28 19:20:58, 75, 120, 0.50, 60000, 72.82, 0.7718
2023-01-28 19:22:14, 75, 150, 0.50, 60000, 75.14, 0.7976
2023-01-28 19:23:35, 75, 180, 0.50, 60000, 81.58, 0.7713
2023-01-28 19:24:45, 75, 120, 0.70, 60000, 69.24, 0.7995
2023-01-28 19:26:05, 75, 150, 0.70, 60000, 79.67, 0.8073
2023-01-28 19:27:32, 75, 180, 0.70, 60000, 86.56, 0.7961
2023-01-28 19:28:36, 75, 120, 0.90, 60000, 63.52, 0.7969
2023-01-28 19:29:41, 75, 150, 0.90, 60000, 65.18, 0.8048
2023-01-28 19:30:56, 75, 180, 0.90, 60000, 74.20, 0.7876
2023-01-28 19:32:00, 90, 120, 0.50, 60000, 64.35, 0.8007
2023-01-28 19:33:14, 90, 150, 0.50, 60000, 73.57, 0.7566
2023-01-28 19:34:32, 90, 180, 0.50, 60000, 77.42, 0.7509
2023-01-28 19:35:37, 90, 120, 0.70, 60000, 65.33, 0.7802
2023-01-28 19:36:51, 90, 150, 0.70, 60000, 73.39, 0.7738
2023-01-28 19:38:17, 90, 180, 0.70, 60000, 85.89, 0.7705
2023-01-28 19:39:33, 90, 120, 0.90, 60000, 75.59, 0.7909
2023-01-28 19:40:54, 90, 150, 0.90, 60000, 81.05, 0.7916
2023-01-28 19:42:16, 90, 180, 0.90, 60000, 81.27, 0.8134

Несколько результатов перевалили 0.8! Чтож, идем дальше.

Запуск 3.

90 <= iterations <= 120,
180 <= locIterations <= 210,
0.85 <= lambda <= 1.15

Date,Iterations,LocIterations,Lambda,TrainSize,TimeFit(sec),Accuracy
2023-01-29 12:31:26, 90, 180, 0.85, 60000, 93.90, 0.7799
2023-01-29 12:32:46, 90, 195, 0.85, 60000, 80.17, 0.7791
2023-01-29 12:34:09, 90, 210, 0.85, 60000, 82.28, 0.8007
2023-01-29 12:35:21, 90, 180, 1.00, 60000, 71.93, 0.8104
2023-01-29 12:36:39, 90, 195, 1.00, 60000, 78.42, 0.8123
2023-01-29 12:37:59, 90, 210, 1.00, 60000, 79.16, 0.7845
2023-01-29 12:39:20, 90, 180, 1.15, 60000, 81.12, 0.8104
2023-01-29 12:40:47, 90, 195, 1.15, 60000, 86.49, 0.8106
2023-01-29 12:42:14, 90, 210, 1.15, 60000, 87.02, 0.8155
2023-01-29 12:43:40, 105, 180, 0.85, 60000, 85.45, 0.8104
2023-01-29 12:45:03, 105, 195, 0.85, 60000, 82.65, 0.8158
2023-01-29 12:46:38, 105, 210, 0.85, 60000, 94.75, 0.7681
2023-01-29 12:48:01, 105, 180, 1.00, 60000, 83.18, 0.7818
2023-01-29 12:49:22, 105, 195, 1.00, 60000, 81.07, 0.7776
2023-01-29 12:50:53, 105, 210, 1.00, 60000, 90.66, 0.8112
2023-01-29 12:52:15, 105, 180, 1.15, 60000, 81.24, 0.7875
2023-01-29 12:53:44, 105, 195, 1.15, 60000, 88.50, 0.7878
2023-01-29 12:55:17, 105, 210, 1.15, 60000, 92.72, 0.7800
2023-01-29 12:56:50, 120, 180, 0.85, 60000, 93.43, 0.7560
2023-01-29 12:58:26, 120, 195, 0.85, 60000, 95.38, 0.7565
2023-01-29 13:00:10, 120, 210, 0.85, 60000, 104.20, 0.8012
2023-01-29 13:01:37, 120, 180, 1.00, 60000, 86.75, 0.8116
2023-01-29 13:03:10, 120, 195, 1.00, 60000, 92.24, 0.8138
2023-01-29 13:04:47, 120, 210, 1.00, 60000, 96.52, 0.7745
2023-01-29 13:06:19, 120, 180, 1.15, 60000, 91.86, 0.8139
2023-01-29 13:07:51, 120, 195, 1.15, 60000, 92.08, 0.8142
2023-01-29 13:09:33, 120, 210, 1.15, 60000, 101.41, 0.8164

Теперь лучшие превышают 0.815. Достаточно много результатов превышающих 0.81.

Про параметр lambda. Я так и не нашел в каком диапазоне его нужно задавать, и что будет при значениях больше 1. Как видим и при большем 1, алгоритм выполняется и даже показывает лучший результат. Но дальнейшие эксперименты показали что значительное его увеличение только ухудшит точность. 

Дальнейшие эксперименты уже менее результативны, еще немного увеличиваю параметры и запускаю последнюю серию.

Запуск 4. 

110 <= iterations <= 140
210 <= locIterations <= 230
0.9 <= lambda <= 1.1
Последняя серия запусков внушительна и долго выполняется.
Date,Iterations,LocIterations,Lambda,TrainSize,TimeFit(sec),Accuracy
2023-01-29 16:41:01, 110, 210, 0.90, 60000, 99.82, 0.7601
2023-01-29 16:42:38, 110, 220, 0.90, 60000, 96.51, 0.8115
2023-01-29 16:44:20, 110, 230, 0.90, 60000, 101.71, 0.8046
2023-01-29 16:45:50, 110, 210, 0.95, 60000, 90.32, 0.7646
2023-01-29 16:47:28, 110, 220, 0.95, 60000, 97.06, 0.8122
2023-01-29 16:49:05, 110, 230, 0.95, 60000, 97.47, 0.7512
2023-01-29 16:52:15, 110, 210, 0.95, 60000, 105.01, 0.7646
2023-01-29 16:53:55, 110, 220, 0.95, 60000, 100.05, 0.8122
2023-01-29 16:55:31, 110, 230, 0.95, 60000, 95.32, 0.7512
2023-01-29 16:56:59, 110, 210, 1.00, 60000, 87.98, 0.7651
2023-01-29 16:58:27, 110, 220, 1.00, 60000, 88.00, 0.8182
2023-01-29 17:00:01, 110, 230, 1.00, 60000, 93.29, 0.8140
2023-01-29 17:01:30, 110, 210, 1.05, 60000, 89.24, 0.7772
2023-01-29 17:03:04, 110, 220, 1.05, 60000, 93.38, 0.8176
2023-01-29 17:04:35, 110, 230, 1.05, 60000, 90.60, 0.8221
2023-01-29 17:06:04, 110, 210, 1.10, 60000, 88.70, 0.7769
2023-01-29 17:07:34, 110, 220, 1.10, 60000, 90.50, 0.8181
2023-01-29 17:09:15, 110, 230, 1.10, 60000, 100.70, 0.7739
2023-01-29 17:10:52, 125, 210, 0.95, 60000, 96.40, 0.8135
2023-01-29 17:12:32, 125, 220, 0.95, 60000, 99.55, 0.7610
2023-01-29 17:14:13, 125, 230, 0.95, 60000, 100.70, 0.8191
2023-01-29 17:15:44, 125, 210, 1.00, 60000, 91.36, 0.8062
2023-01-29 17:17:19, 125, 220, 1.00, 60000, 94.74, 0.7627
2023-01-29 17:19:04, 125, 230, 1.00, 60000, 104.12, 0.7560
2023-01-29 17:20:39, 125, 210, 1.05, 60000, 95.46, 0.8045
2023-01-29 17:22:13, 125, 220, 1.05, 60000, 93.98, 0.7646
2023-01-29 17:23:50, 125, 230, 1.05, 60000, 96.65, 0.7640
2023-01-29 17:25:26, 125, 210, 1.10, 60000, 95.16, 0.8086
2023-01-29 17:27:05, 125, 220, 1.10, 60000, 99.18, 0.7726
2023-01-29 17:28:48, 125, 230, 1.10, 60000, 102.42, 0.8167
2023-01-29 17:30:29, 140, 210, 0.95, 60000, 101.26, 0.7463
2023-01-29 17:32:18, 140, 220, 0.95, 60000, 108.89, 0.8122
2023-01-29 17:34:08, 140, 230, 0.95, 60000, 109.26, 0.7412
2023-01-29 17:36:00, 140, 210, 1.00, 60000, 112.27, 0.7447
2023-01-29 17:37:57, 140, 220, 1.00, 60000, 116.47, 0.8138
2023-01-29 17:39:55, 140, 230, 1.00, 60000, 117.74, 0.8040
2023-01-29 17:41:41, 140, 210, 1.05, 60000, 105.95, 0.7518
2023-01-29 17:43:27, 140, 220, 1.05, 60000, 105.17, 0.8106
2023-01-29 17:45:14, 140, 230, 1.05, 60000, 107.28, 0.8143
2023-01-29 17:46:56, 140, 210, 1.10, 60000, 101.40, 0.7549
2023-01-29 17:48:42, 140, 220, 1.10, 60000, 106.20, 0.8133
2023-01-29 17:50:31, 140, 230, 1.10, 60000, 108.67, 0.7542
Зато тут лучший результат получается 0,82.
Маловато конечно. Прочитав эту статью про mnist и svm убедился вроде как норм. Там тоже получается с svm точность 0.84. Остановимся на этом.

Итоги

Результат с точностью в районе 0.8 честно говоря меня не впечатлил. Полагаю данный алгоритм в идеале предназначен для несколько других датасетов и задач. А классификацию изображений оставим для нейросетей. 

Но считаю данный опыт полезным для изучения ml. Знакомство с apache ignite ml считаю успешным. Порадовало множетсво примеров в репозитории и оф дока, только хотелось бы чуть больше подробностей там, особенно про параметры. 
Собрал и запустил рабочий пример, поэкспериментировал с параметрами, что только увеличило интерес к дальшейшему изучению ml.

--- 

Github получившегося примера.

Ресурсы: 
Apache Ignite ML documentation оф.документация, раздел про ml.
Статья на хабре с которой началось знакомство.
OneVsResrClassification пример классификации из официального репозитория Apache Ignite.
MNIST in CSV датасет c kaggle.
Про аглоритм SVM краткое описание на хабре.
One-vs-Rest, One-vs-one описание подходов при мульклассовой классификации.
Multiclass classification with SVM пример на питоне

0 comments:

Отправить комментарий