Newer
Older
ai / lab8 / id3.py
@Andreas Jaggi Andreas Jaggi on 29 May 2006 13 KB Added Lab8, an Decision Tree System (ID3)
# -*- coding: cp1252 -*-
#################################################################################
print "Chargement du module ID3"
#################################################################################

import math
import copy
import sys

## La liste des exemples, selon le format suivant:
##    exemples ::= [ { exemple } ]
##    exemple ::= [ classe valeur-attribut-1 ... valeur-attribut-n]
allExamples = []

## La liste des attributs (dans le meme ordre que celui des exemples)
## avec leurs valeurs. Format:
##    attributs ::= [ classes { attribut } ]
##    classes ::= [ nom-attribut classe-1 .. classe-m ]
##    attribut ::= [ nom-attribut valeur-1 .. valeur-n ]
attributesAndValues = []

## L'arbre de decision trouve:
classification = []


###===================================================================
###   INITIALISATION et CONSULTATION DE L'ARBRE DE DECISION
###===================================================================

## Initialise le module ID3 avec les exemples <examples> et la
## description des classes et attributs <attributes-and-values>.
def initID3(examples, theAttributesAndValues):
	global allExamples
	global attributesAndValues
	global classification
	allExamples = examples
	attributesAndValues = theAttributesAndValues
	classification = []
	print("DB: initID3 -> attributesAndValues: %s" % (attributesAndValues))

## Permet de classer un exemple a l'aide de l'arbre de decision
## classification, calculé par buildDecisionTree.
def classify(n=classification):
	if n == []:
		## On ne sait pas classer cet exemple
		print("L'exemple ne peut être classé")
	elif n.terminalNode():
		## L'exemple peut être classé:
		print(">>> L'exemple est de la classe %s" % (n.getClass()))
	else:
		## L'exemple ne peut encore être classé, il faut des informations
		## complémentaires:
		splitAttribute = n.attribute
		validValues = getAllAttributeValues(splitAttribute)
		while True:
			print("> Valeur de l'attribut %s? " % (attributeName(splitAttribute)))
			value = sys.stdin.readline().strip()
			if value in validValues:
				return classify(n.children[validValues.index(value)])
			print("## Valeur %s inconnue pour l'attribut %s" %
				  (value, attributeName(splitAttribute)))



###===================================================================
###   CALCUL DE L'ARBRE DE DECISION
###===================================================================

class Node:
	## Constructeur
	## @param attribute L'attribut de partitionnement du noeud.
	## @param examples Les exemples du noeud (s'il est terminal).
	## @param children Les noeuds fils (s'il n'est pas terminal).
	def __init__(self, attribute, examples, children):
		self.attribute = attribute
		self.examples = examples
		self.children = children

	def __str__(self):
		out = ("\n********** DECISION TREE **********\n")
		return out + self.printNodeAux("  ")

	def printNodeAux(self, space):
		if self.terminalNode():
			## Noeud terminal:
			out = ("%s Noeud terminal: classe %s\n" % (space, self.getClass()))
			for ex in self.examples:
				out = out + ("%s  %s\n" % (space, ex))
		else:
			## Noeud non terminal:
			out = ("%s Partitionnement selon %s\n" % (space, attributeName(self.attribute)))
			values = getAllAttributeValues(self.attribute)
			successors = self.children
			newSpace = space + "	"
			while len(successors) > 0:
				out = out + ("%s  Valeur %s\n" % (space, values[0]))
				out = out + successors[0].printNodeAux(newSpace)
				values = values[1:]
				successors = successors[1:]
		return out

	## Retourne la classe correspondant au noeud <n> s'il est terminal,
	## [] si c'est un noeud intermediaire
	def getClass(self):
		if self.terminalNode():
			return getClassValue(self.examples[0])

	## Retourne True ssi self est un noeud terminal, càd s'il n'a pas
	## d'attribut de partitionnement
	def terminalNode(self):
		return (self.attribute == [])


## Construit l'arbre de decision pour les exemples de allExamples et
## la sauvegarde dans la variable globale classification.
def buildDecisionTree():
	global allExamples
	global classification
	classification = buildDecisionTreeAux(allExamples, buildAttributesList())
	#print(classification)

def buildDecisionTreeAux(examples, attributes):
	if not examples:
		return []

	if oneClass(examples):
		return Node([], examples, [])

	a = bestAttribute(attributes, examples)

	part = partition(examples, a)

	fils = []
	for p in part:
		fils.append(buildDecisionTreeAux(p, list(set(attributes).difference(set([a])))))

	return Node(a, [], fils)

##===================================================================
##   CALCUL DE L'ENTROPIE
##===================================================================

## Retourne le meilleur attribut, càd celui avec l'entropie
## (càd l'incertitude) la plus faible.

def bestAttribute(attributes, examples):
	print("DB: bestAttribute -> attributes: %s" % (attributes))
	return attributes[getIndexOfMinimum(map(lambda attribute : entropy(attribute, examples), attributes))]

## Retourne l'entropie totale pour <attribute>, a savoir:
##	  H(C|A) = H(C|attribute).
def entropy(attribute, examples):
	s = 0.0

	for v in getAllAttributeValues(attribute):
		s = s + P_Aj(attribute, v, examples) * H_C_Aj(attribute, v, examples)

	return s

## Retourne l'entropie conditionnelle de <attribute> avec pour
## valeur <value>, en d'autres termes:
##	  H(C|Aj) = H(C|attribute=value).
def H_C_Aj(attribute, value, examples):
	s = 0.0

	for c in getPossibleClassValues(examples):
		t = P_Ci_Aj(c, attribute, value, examples)
		if t != 0:
			s = s + t * math.log(t, 2)

	return -s

def calculateProb(classValue,attribute, value, examples):
	prob = P_Ci_Aj(classValue, attribute, value, examples)
	#print("DB: calculateProb -> prob: %s" % (prob))
	if prob == 0:
		return 0
	else:
		return prob * math.log(prob,2)

## Calcule la probabilite que la valeur de la classe est
## <class-value> lorsque <attribut> vaut <value>, en d'autres
## termes: P(Ci|Aj) = P(class-value|attribute=value).
def P_Ci_Aj(classValue, attribute, value, examples):

	ct = 0.0
	for e in examples:
		if getClassValue(e) == classValue:
			if getAttributeValue(e, attribute) == value:
				ct = ct+1
	pAB = ct/len(examples)
	pB = P_Aj(attribute, value, examples)

	if pB > 0:
		return pAB/pB
	else:
		return 0

## Retourne la probabilite que <attribut> vaut <value>.
def P_Aj(attribute, value, examples):
	return countOccurences(attribute, value, examples)/len(examples)

##===================================================================
##   EXEMPLES
##===================================================================

## Retourne la valeur de l'attribut d'index <attribute> pour
## <example>.
def getAttributeValue(example, attribute):
	return example[attribute]

## Retourne la classe de <example>.
def getClassValue(example):
	return example[0]


##===================================================================
##   OPERATIONS SUR UNE LISTE D'EXEMPLES
##===================================================================

## Partitionne les <examples> selon la valeur de <attribute>.
## Retourne une liste tq chaque element est une liste des exemples
## avec un valeur commune pour <attribute>.
def partition(examples, attribute):
	r = []
	for v in getAllAttributeValues(attribute):
		r.append(findAll(examples, lambda e: (getAttributeValue(e, attribute) == v)))
	return r

## Retourne True ssi tous les exemples de <examples> font partie de la
## meme classe.
def oneClass(examples):
	res = len(getPossibleClassValues(examples)) == 1
	print("DB: --------------------- oneClass -> getPossibleClassValues(examples) %s" % (getPossibleClassValues(examples)))
	print("DB: --------------------- oneClass -> res %s" % (res))
	return res

## Retourne la liste de toutes les valeurs pour l'attribut d'index
## <attribute> que l'on trouve dans les exemples de <examples>.
def getPossibleAttributeValues(attribute, examples):
	res = []
	for ex in examples:
		value = getAttributeValue(ex,attribute)
		if not value in res:
			res = [value] + res
	return res

## Retourne la liste de toutes les classes que l'on trouve dans
## <examples>.
def getPossibleClassValues(examples):
	res= getPossibleAttributeValues(0,examples)
	#print("DB: getPossibleClassValues -> res: %s" % (res))
	return res

## Retourne le nombre d'exemples de <examples> dont l'attribut
## d'index <attribute> vaut <value>.
def countOccurences(attribute, value, examples):
	res = 0
	for ex in examples:
		if getAttributeValue(ex,attribute) == value:
			res = res + 1
#	if res <> 0:
#		print("DB: countOccurences -> res: DIFFERENT FROM ZERO!!!!!!!!!! %s" % (res))
	return res

## Retourne le nombre d'exemples de <examples> appartenant a la
## classe <class-value> et qui ont <value> comme valeur de
## l'attribut d'index <attribute>.
def countOccurencesIf(classValue, attribute, value, examples):
	res = 0
	for ex in examples:
		if (getClassValue(ex) == classValue) and (getAttributeValue(ex,attribute) == value):
			res = res + 1
	return res


##===================================================================
##   UTILITAIRES
##===================================================================

## Retourne la liste de toutes les valeurs possibles de l'attribut
## d'index <attribute>.
def getAllAttributeValues(attribute):
	global attributesAndValues
	return attributesAndValues[attribute][1:]

## ATTENTION!!!!!!!!!!!!!!!!!!!!!!!
## VERIFIER QUE LES INDEXS SONT CORRECTS


## Retourne la liste des index des attributs, i.e. une liste de 1 a
## last-1.
def buildAttributesList(current=1,last=-1):
	if last == -1:
		last = len(attributesAndValues)
	if current < last:
		return [current] + buildAttributesList(current + 1, last)
	else:
		return []

## Retourne le nom de l'attribut d'index <attribute>.
def attributeName(attribute):
	global attributesAndValues
	return attributesAndValues[attribute][0]

## Retourne l'index de la plus petite valeur de la liste <l>.
def getIndexOfMinimum(l):
	print("DB: getIndexOfMinimum -> l: %s" % (l))
	return l.index(min(l))

## Retourne la liste de tous les elements de la liste <l> qui
## satisfont <test>.
def findAll(l, test):
	res = []
	for item in l:
		if test(item):
			res = [item] + res
	return res



## TESTING -----------------------------------------
## Construction d'un arbre de decision (ID3).
## Exemple: estimation du profit d'une entreprise informatique.
def test1():
	initID3([['down',  'old',  'no', 'software'],
	 ['down','midlife','yes','software'],
	 ['up',  'midlife','no', 'hardware'],
	 ['down','old',    'no', 'hardware'],
	 ['up',  'new',    'no', 'hardware'],
	 ['up',  'new',    'no', 'software'],
	 ['up',  'midlife','no', 'software'],
	 ['up',  'new',    'yes','software'],
	 ['down','midlife','yes','hardware'],
	 ['down','old',    'yes','software']],
	[['profit', 'down','up'],
	 ['age', 'old', 'midlife', 'new'],
	 ['competition', 'no', 'yes'],
	 ['type', 'software', 'hardware']])
	buildDecisionTree()


## Construction d'un arbre de decision (ID3).
## Exemple: maladies chez l'enfant

def test2():
	initID3([['angine-erythemateuse','elevee','gonflees',    'oui','oui','non','non','non','normale','normales',     'normaux'],
		 ['angine-pultacee',     'elevee','point-blancs','oui','oui','non','non','non','normale','normales',     'normaux'],
		 ['angine-diphterique',  'legere','enduit-blanc','oui','oui','non','non','non','normale','normales',     'normaux'],
		 ['appendicite',         'legere','normales',    'non','non','oui','non','non','normale','normales',     'normaux'],
		 ['bronchite',           'legere','normales',    'oui','non','non','oui','oui','gene',   'normales',     'normaux'],
		 ['coqueluche',          'legere','normales',    'non','oui','non','oui','oui','gene',   'normales',     'normaux'],
		 ['pneumonie',           'elevee','normales',    'non','non','non','oui','non','rapide', 'rouges',       'normaux'],
		 ['rougeole',            'legere','normales',    'non','oui','non','oui','oui','normale','normales',     'larmoyants'],
		 ['rougeole',            'legere','normales',    'non','oui','non','oui','oui','normale','taches-rouges','larmoyants'],
		 ['rubeole',             'legere','normales',    'oui','non','non','non','non','normale','taches-rouges','normaux'],
		 ['rubeole',             'non',   'normales',    'oui','non','non','non','non','normale','taches-rouges','normaux'],
		 ['rubeole',             'non',   'normales',    'oui','non','non','non','non','normale','normales',	 'normaux']],
		[['maladies','angine-erythemateuse','angine-pultacee','angine-diphterique','appendicite','bronchite','coqueluche','pneumonie','rougeole','rubeole'],
		 ['fievre','non','legere','elevee'],
		 ['amygdales','normales','gonflees', 'points-blancs', 'enduit-blanc'],
		 ['ganglions','non','oui'],
		 ['gene-a-avaler','non','oui'],
		 ['mal-au-ventre','non','oui'],
		 ['toux','non','oui'],
		 ['rhume','non','oui'],
		 ['respiration','normale','gene','rapide'],
		 ['joues','normales','rouges','taches-rouges'],
		 ['yeux','normaux','larmoyants']])
	buildDecisionTree()




if __name__ == '__main__':
	#test1()
	test2()
	#execfile('invest.py')
	classify(classification)