Commit 8cb90066 authored by Vlad Apetrei's avatar Vlad Apetrei Committed by David Monllaó
Browse files

MDL-58992 analytics: add multi-classification to predictors

Adds multi-class capabilities to prediction processors
as well as multi-classification unit tests
parent a672f021
<?php
// This file is part of Moodle - http://moodle.org/
//
// Moodle is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Moodle is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Moodle. If not, see <http://www.gnu.org/licenses/>.
/**
* Multiclass test indicator.
*
* @package core_analytics
* @copyright 2019 Vlad Apetrei
* @license http://www.gnu.org/copyleft/gpl.html GNU GPL v3 or later
*/
defined('MOODLE_INTERNAL') || die();
/**
* Multiclass test indicator.
*
* @package core_analytics
* @copyright 2019 Vlad Apetrei
* @license http://www.gnu.org/copyleft/gpl.html GNU GPL v3 or later
*/
class test_indicator_multiclass extends \core_analytics\local\indicator\linear {
/**
* Returns a lang_string object representing the name for the indicator.
*
* Used as column identificator.
*
* If there is a corresponding '_help' string this will be shown as well.
*
* @return \lang_string
*/
public static function get_name() : \lang_string {
// Using a string that exists and contains a corresponding '_help' string.
return new \lang_string('allowstealthmodules');
}
/**
* include_averages
*
* @return bool
*/
protected static function include_averages() {
return false;
}
/**
* required_sample_data
*
* @return string[]
*/
public static function required_sample_data() {
return array('course');
}
/**
* calculate_sample
*
* @param int $sampleid
* @param string $samplesorigin
* @param int $starttime
* @param int $endtime
* @return float
*/
protected function calculate_sample($sampleid, $samplesorigin, $starttime, $endtime) {
$course = $this->retrieve('course', $sampleid);
$firstchar = substr($course->fullname, 0, 1);
if ($firstchar === 'a') {
return 1;
} else if ($firstchar === 'b') {
return -1;
} else if ($firstchar === 'c') {
return 1;
} else {
return self::MAX_VALUE;
}
}
}
<?php
// This file is part of Moodle - http://moodle.org/
//
// Moodle is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Moodle is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Moodle. If not, see <http://www.gnu.org/licenses/>.
/**
* Multi-class classifier target.
*
* @package core_analytics
* @copyright 2019 Apetrei Vlad
* @license http://www.gnu.org/copyleft/gpl.html GNU GPL v3 or later
*/
defined('MOODLE_INTERNAL') || die();
/**
* Multi-class classifier target.
*
* @package core_analytics
* @copyright 2019 Apetrei Vlad
* @license http://www.gnu.org/copyleft/gpl.html GNU GPL v3 or later
*/
class test_target_shortname_multiclass extends \core_analytics\local\target\discrete {
/**
* Returns a lang_string object representing the name for the indicator.
*
* Used as column identificator.
*
* If there is a corresponding '_help' string this will be shown as well.
*
* @return \lang_string
*/
public static function get_name() : \lang_string {
// Using a string that exists and contains a corresponding '_help' string.
return new \lang_string('allowstealthmodules');
}
/**
* predictions
*
* @var array
*/
protected $predictions = array();
/**
* is_linear
*
* @return bool
*/
public function is_linear() {
return false;
}
/**
* Returns the target discrete values.
*
* Only useful for targets using discrete values, must be overwriten if it is the case.
*
* @return array
*/
public static final function get_classes() {
return array(0, 1, 2);
}
/**
* Is the calculated value a positive outcome of this target?
*
* @param string $value
* @param string $ignoredsubtype
* @return int
*/
public function get_calculation_outcome($value, $ignoredsubtype = false) {
if (!self::is_a_class($value)) {
throw new \moodle_exception('errorpredictionformat', 'analytics');
}
if (in_array($value, $this->ignored_predicted_classes(), false)) {
// Just in case, if it is ignored the prediction should not even be recorded but if it would, it is ignored now,
// which should mean that is it nothing serious.
return self::OUTCOME_VERY_POSITIVE;
}
// By default binaries are danger when prediction = 1.
if ($value) {
return self::OUTCOME_VERY_NEGATIVE;
}
return self::OUTCOME_VERY_POSITIVE;
}
/**
* get_analyser_class
*
* @return string
*/
public function get_analyser_class() {
return '\core\analytics\analyser\site_courses';
}
/**
* We don't want to discard results.
* @return float
*/
protected function min_prediction_score() {
return null;
}
/**
* We don't want to discard results.
* @return array
*/
public function ignored_predicted_classes() {
return array();
}
/**
* is_valid_analysable
*
* @param \core_analytics\analysable $analysable
* @param bool $fortraining
* @return bool
*/
public function is_valid_analysable(\core_analytics\analysable $analysable, $fortraining = true) {
// This is testing, let's make things easy.
return true;
}
/**
* is_valid_sample
*
* @param int $sampleid
* @param \core_analytics\analysable $analysable
* @param bool $fortraining
* @return bool
*/
public function is_valid_sample($sampleid, \core_analytics\analysable $analysable, $fortraining = true) {
// We skip not-visible courses during training as a way to emulate the training data / prediction data difference.
// In normal circumstances is_valid_sample will return false when they receive a sample that can not be
// processed.
if (!$fortraining) {
return true;
}
$sample = $this->retrieve('course', $sampleid);
if ($sample->visible == 0) {
return false;
}
return true;
}
/**
* classes_description
*
* @return string[]
*/
protected static function classes_description() {
return array(
get_string('first class'),
get_string('second class'),
get_string('third class')
);
}
/**
* calculate_sample
*
* @param int $sampleid
* @param \core_analytics\analysable $analysable
* @param int $starttime
* @param int $endtime
* @return float
*/
protected function calculate_sample($sampleid, \core_analytics\analysable $analysable, $starttime = false, $endtime = false) {
$sample = $this->retrieve('course', $sampleid);
$firstchar = substr($sample->shortname, 0, 1);
switch ($firstchar) {
case 'a':
return 0;
case 'b':
return 1;
case 'c':
return 2;
}
}
/**
* Can the provided time-splitting method be used on this target?.
*
* Time-splitting methods not matching the target requirements will not be selectable by models based on this target.
*
* @param \core_analytics\local\time_splitting\base $timesplitting
* @return bool
*/
public function can_use_timesplitting(\core_analytics\local\time_splitting\base $timesplitting):bool {
return true;
}
}
......@@ -30,7 +30,9 @@ require_once(__DIR__ . '/fixtures/test_indicator_min.php');
require_once(__DIR__ . '/fixtures/test_indicator_null.php');
require_once(__DIR__ . '/fixtures/test_indicator_fullname.php');
require_once(__DIR__ . '/fixtures/test_indicator_random.php');
require_once(__DIR__ . '/fixtures/test_indicator_multiclass.php');
require_once(__DIR__ . '/fixtures/test_target_shortname.php');
require_once(__DIR__ . '/fixtures/test_target_shortname_multiclass.php');
require_once(__DIR__ . '/fixtures/test_static_target_shortname.php');
require_once(__DIR__ . '/../../course/lib.php');
......@@ -433,6 +435,70 @@ class core_analytics_prediction_testcase extends advanced_testcase {
return $this->add_prediction_processors($cases);
}
/**
* Tests correct multi-classification.
*
* @dataProvider provider_test_multi_classifier
* @param string $timesplittingid
* @param string $predictionsprocessorclass
* @throws coding_exception
* @throws moodle_exception
*/
public function test_ml_multi_classifier($timesplittingid, $predictionsprocessorclass) {
global $DB;
$this->resetAfterTest(true);
$this->setAdminuser();
set_config('enabled_stores', 'logstore_standard', 'tool_log');
$predictionsprocessor = \core_analytics\manager::get_predictions_processor($predictionsprocessorclass, false);
if ($predictionsprocessor->is_ready() !== true) {
$this->markTestSkipped('Skipping ' . $predictionsprocessorclass . ' as the predictor is not ready.');
}
// Generate training courses.
$ncourses = 5;
$this->generate_courses_multiclass($ncourses);
$model = $this->add_multiclass_model();
$model->update(true, false, $timesplittingid, get_class($predictionsprocessor));
$results = $model->train();
$params = [
'startdate' => mktime(0, 0, 0, 10, 24, 2015),
'enddate' => mktime(0, 0, 0, 2, 24, 2016),
];
$courseparams = $params + array('shortname' => 'aaaaaa', 'fullname' => 'aaaaaa', 'visible' => 0);
$course1 = $this->getDataGenerator()->create_course($courseparams);
$courseparams = $params + array('shortname' => 'bbbbbb', 'fullname' => 'bbbbbb', 'visible' => 0);
$course2 = $this->getDataGenerator()->create_course($courseparams);
$courseparams = $params + array('shortname' => 'cccccc', 'fullname' => 'cccccc', 'visible' => 0);
$course3 = $this->getDataGenerator()->create_course($courseparams);
// They will not be skipped for prediction though.
$result = $model->predict();
// The $course1 predictions should be 0 == 'a', $course2 should be 1 == 'b' and $course3 should be 2 == 'c'.
$correct = array($course1->id => 0, $course2->id => 1, $course3->id => 2);
foreach ($result->predictions as $uniquesampleid => $predictiondata) {
list($sampleid, $rangeindex) = $model->get_time_splitting()->infer_sample_info($uniquesampleid);
// The range index is not important here, both ranges prediction will be the same.
$this->assertEquals($correct[$sampleid], $predictiondata->prediction);
}
}
/**
* Provider for the multi_classification test.
*
* @return array
*/
public function provider_test_multi_classifier() {
$cases = array(
'notimesplitting' => array('\core\analytics\time_splitting\no_splitting'),
);
// Add all system prediction processors.
return $this->add_prediction_processors($cases);
}
/**
* Basic test to check that prediction processors work as expected.
*
......@@ -670,7 +736,6 @@ class core_analytics_prediction_testcase extends advanced_testcase {
* @return \core_analytics\model
*/
protected function add_perfect_model($targetclass = 'test_target_shortname') {
$target = \core_analytics\manager::get_target($targetclass);
$indicators = array('test_indicator_max', 'test_indicator_min', 'test_indicator_fullname');
foreach ($indicators as $key => $indicator) {
......@@ -683,6 +748,25 @@ class core_analytics_prediction_testcase extends advanced_testcase {
return new \core_analytics\model($model->get_id());
}
/**
* Generates model for multi-classification
*
* @param string $targetclass
* @return \core_analytics\model
* @throws coding_exception
* @throws moodle_exception
*/
public function add_multiclass_model($targetclass = 'test_target_shortname_multiclass') {
$target = \core_analytics\manager::get_target($targetclass);
$indicators = array('test_indicator_fullname', 'test_indicator_multiclass');
foreach ($indicators as $key => $indicator) {
$indicators[$key] = \core_analytics\manager::get_indicator($indicator);
}
$model = \core_analytics\model::create($target, $indicators);
return new \core_analytics\model($model->get_id());
}
/**
* Generates $ncourses courses
*
......@@ -709,6 +793,37 @@ class core_analytics_prediction_testcase extends advanced_testcase {
}
}
/**
* Generates ncourses for multi-classification
*
* @param int $ncourses The number of courses to be generated.
* @param array $params Course params
* @return null
*/
protected function generate_courses_multiclass($ncourses, array $params = []) {
$params = $params + [
'startdate' => mktime(0, 0, 0, 10, 24, 2015),
'enddate' => mktime(0, 0, 0, 2, 24, 2016),
];
for ($i = 0; $i < $ncourses; $i++) {
$name = 'a' . random_string(10);
$courseparams = array('shortname' => $name, 'fullname' => $name) + $params;
$this->getDataGenerator()->create_course($courseparams);
}
for ($i = 0; $i < $ncourses; $i++) {
$name = 'b' . random_string(10);
$courseparams = array('shortname' => $name, 'fullname' => $name) + $params;
$this->getDataGenerator()->create_course($courseparams);
}
for ($i = 0; $i < $ncourses; $i++) {
$name = 'c' . random_string(10);
$courseparams = array('shortname' => $name, 'fullname' => $name) + $params;
$this->getDataGenerator()->create_course($courseparams);
}
}
/**
* add_prediction_processors
*
......
......@@ -132,8 +132,7 @@ class processor implements \core_analytics\classifier, \core_analytics\regressor
$nsamples = count($samples);
if ($nsamples === self::BATCH_SIZE) {
// Training it batches to avoid running out of memory.
$classifier->partialTrain($samples, $targets, array(0, 1));
$classifier->partialTrain($samples, $targets, json_decode($metadata['targetclasses']));
$samples = array();
$targets = array();
}
......@@ -152,7 +151,7 @@ class processor implements \core_analytics\classifier, \core_analytics\regressor
// Train the remaining samples.
if ($samples) {
$classifier->partialTrain($samples, $targets, array(0, 1));
$classifier->partialTrain($samples, $targets, json_decode($metadata['targetclasses']));
}
$resultobj = new \stdClass();
......
......@@ -38,7 +38,7 @@ class processor implements \core_analytics\classifier, \core_analytics\regresso
/**
* The required version of the python package that performs all calculations.
*/
const REQUIRED_PIP_PACKAGE_VERSION = '2.0.0';
const REQUIRED_PIP_PACKAGE_VERSION = '2.1.0';
/**
* The path to the Python bin.
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment