
Applying logistic regression to classify human-machine and human-human dialogue

Jessica / 2019-05-03


Although logistic regression has "regression" in its name, it is commonly used for classification (binary or multi-class).
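Concretely, the model estimates the probability of the positive class with the logistic (sigmoid) function, and a cutoff on that probability (0.5 by default) produces the class label:

P(y = 1 | x) = 1 / (1 + exp(-(b0 + b1*x1 + ... + bk*xk)))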

The human-machine dialogues are drawn from the chatterbot training corpus, and the human-human dialogues from the Tencent AI Lab dialogue dataset (see this paper). Fifty single-turn dialogues (none tied to a specific topic) were randomly sampled from each dataset. From the two groups of dialogues, 24 quantitative features were extracted (lexical richness, dialogue length, average sentence length, function-word ratio, per-part-of-speech usage frequencies, and so on). These features serve as X, and whether the dialogue is human-machine serves as Y (yes: 1, no: 0), which lets us classify human-human versus human-machine dialogue.
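The feature-extraction code is not shown in this post, but as a rough illustration, a few of the lexical features can be computed from a token list as below. This is a minimal sketch under my own assumptions: the Chinese text is already word-segmented, entropy uses the natural log, and the Simpson index is taken as the sum of squared token probabilities (the repeat rate); the part-of-speech counts such as Na, Nh, or VA would come from a POS tagger and are not reproduced here.

import math
from collections import Counter

def lexical_features(tokens):
    # tokens: the list of word tokens from one (segmented) dialogue
    counts = Counter(tokens)
    n = len(tokens)
    probs = [c / n for c in counts.values()]
    return {
        'Text Length': n,
        'Different Words': len(counts),
        'TTR': len(counts) / n,                           # type-token ratio
        'Entropy': -sum(p * math.log(p) for p in probs),  # lexical diversity
        'Simpson Index': sum(p * p for p in probs),       # repeat rate
    }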

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.feature_selection import f_regression
plt.style.use('ggplot')
plt.rcParams['font.family'] = 'SimHei'  # SimHei, a Chinese font, so Chinese labels render in plots

Read in the data

data = pd.read_excel('sample_50_anno.xlsx',sheet_name = None) 
lrdata=data.get('lr_data') # get a specific sheet to DataFrame

Use the following numeric features to predict whether a dialogue is human-machine: 'Text Length', 'Different Words', 'Entropy', 'Simpson Index', 'Sentence Count', 'Sentence Length Average', 'Sentence Length Variance', 'Function Word Count', 'Function Word Proportion', 'TTR', 'Na', 'Nb', 'Nc', 'Nd', 'Nh', 'D', 'T', 'VA', 'VB', 'VC', 'VD', 'VE', 'VH', 'C'.

df=lrdata[['Text Length','Different Words','Entropy','Simpson Index','Sentence Count','Sentence Length Average','Sentence Length Variance','Function Word Count','Function Word Proportion','TTR','Na','Nb','Nc','Nd','Nh','D','T','VA','VB','VC','VD','VE','VH','C','Machine']]
df.head()
Text Length Different Words Entropy Simpson Index Sentence Count Sentence Length Average Sentence Length Variance Function Word Count Function Word Proportion TTR ... D T VA VB VC VD VE VH C Machine
0 23 10 2.262386 0.107266 15 1.533333 0.466667 6 0.260870 0.435000 ... 0 0 0 0 0 0 1 2 0 0
1 29 14 2.553682 0.085000 18 1.611111 6.277778 5 0.172414 0.482759 ... 2 0 0 0 3 0 0 3 0 0
2 51 22 2.655133 0.129291 25 2.040000 7.240000 2 0.039216 0.431373 ... 2 0 0 0 3 0 0 2 0 0
3 12 8 2.043192 0.135802 7 1.714286 3.428571 0 0.000000 0.666667 ... 0 0 0 0 1 0 0 1 0 0
4 9 8 2.043192 0.135802 9 1.000000 0.000000 4 0.444444 0.888889 ... 2 1 0 0 0 0 0 0 0 0

5 rows × 25 columns

Split into training and test sets

x=df[['Text Length','Different Words','Entropy','Simpson Index','Sentence Count','Sentence Length Average','Sentence Length Variance','Function Word Count','Function Word Proportion','TTR','Na','Nb','Nc','Nd','Nh','D','T','VA','VB','VC','VD','VE','VH','C']]
y=df[['Machine']]

x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.3,random_state=2019) 
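A variant worth knowing (my note, not what this post ran): with only about a hundred samples, passing stratify=y keeps the human/machine class ratio the same in both splits.

# x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=2019, stratify=y)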

x_train
Text Length Different Words Entropy Simpson Index Sentence Count Sentence Length Average Sentence Length Variance Function Word Count Function Word Proportion TTR ... Nh D T VA VB VC VD VE VH C
95 44 29 3.297037 0.040175 7 6.285714 215.428571 12 0.272727 0.659091 ... 3 1 0 0 0 0 0 0 1 2
47 13 10 2.138397 0.147929 9 1.444444 0.555556 2 0.153846 0.769231 ... 0 0 1 1 0 1 0 0 0 0
14 15 8 2.043192 0.135802 9 1.666667 0.333333 0 0.000000 0.533333 ... 0 3 0 3 0 1 0 0 0 0
81 12 9 2.197225 0.111111 2 6.000000 18.000000 0 0.000000 0.750000 ... 3 4 0 4 0 0 0 0 1 0
90 14 12 2.484907 0.083333 2 7.000000 8.000000 2 0.142857 0.857143 ... 1 0 0 0 0 0 0 0 1 0
2 51 22 2.655133 0.129291 25 2.040000 7.240000 2 0.039216 0.431373 ... 0 2 0 0 0 3 0 0 2 0
22 33 19 2.902002 0.057851 20 1.650000 12.550000 3 0.090909 0.575758 ... 1 3 1 0 0 2 0 0 3 0
77 44 38 3.612136 0.027960 2 22.000000 450.000000 7 0.159091 0.863636 ... 0 0 0 0 0 3 0 1 0 1
78 19 16 2.751667 0.065744 2 9.500000 12.500000 2 0.105263 0.842105 ... 2 5 0 0 0 2 0 0 1 0
41 12 6 1.747868 0.183673 7 1.714286 0.285714 2 0.166667 0.500000 ... 1 3 1 0 0 0 0 0 0 0
25 10 6 1.747868 0.183673 7 1.428571 0.571429 1 0.100000 0.600000 ... 1 0 1 0 0 0 0 0 2 0
92 37 34 3.526361 0.029412 2 18.500000 364.500000 7 0.189189 0.918919 ... 0 0 0 1 0 0 0 0 0 1
34 9 5 1.560710 0.222222 6 1.500000 0.500000 1 0.111111 0.555556 ... 1 1 1 1 0 2 0 0 0 0
73 11 9 2.197225 0.111111 2 5.500000 0.500000 3 0.272727 0.818182 ... 0 0 0 0 0 0 0 0 1 0
66 24 21 3.028029 0.049587 2 12.000000 50.000000 2 0.083333 0.875000 ... 4 2 1 0 0 2 0 1 0 0
74 20 11 2.351257 0.098765 2 10.000000 0.000000 6 0.300000 0.550000 ... 2 2 0 0 0 0 0 0 0 0
64 10 7 1.945910 0.142857 2 5.000000 2.000000 3 0.300000 0.700000 ... 1 0 2 0 0 0 0 0 1 0
56 32 22 3.044820 0.050296 6 5.333333 55.333333 4 0.125000 0.687500 ... 0 5 1 0 0 1 0 0 2 0
96 22 20 2.995732 0.050000 2 11.000000 18.000000 3 0.136364 0.909091 ... 2 1 1 0 0 0 0 0 0 0
38 10 9 2.197225 0.111111 9 1.111111 0.888889 2 0.200000 0.900000 ... 1 1 1 1 0 0 0 0 0 0
75 10 8 2.079442 0.125000 2 5.000000 2.000000 2 0.200000 0.800000 ... 2 1 1 0 0 0 0 1 0 0
76 14 10 2.253858 0.111111 2 7.000000 32.000000 2 0.142857 0.714286 ... 1 1 0 0 0 0 0 0 0 0
17 12 6 1.732868 0.187500 6 2.000000 2.000000 0 0.000000 0.500000 ... 0 0 0 0 0 2 0 0 0 0
79 19 14 2.599302 0.078125 2 9.500000 24.500000 3 0.157895 0.736842 ... 4 3 0 0 0 0 0 0 2 0
27 20 14 2.599302 0.078125 14 1.428571 5.428571 2 0.100000 0.700000 ... 1 2 1 0 0 0 0 0 0 0
26 18 9 2.069202 0.142857 14 1.285714 0.714286 3 0.166667 0.500000 ... 3 3 0 0 0 2 0 1 2 0
52 26 21 3.014947 0.051040 2 13.000000 98.000000 4 0.153846 0.807692 ... 3 1 0 0 0 0 0 2 0 1
3 12 8 2.043192 0.135802 7 1.714286 3.428571 0 0.000000 0.666667 ... 1 0 0 0 0 1 0 0 1 0
87 136 113 4.699810 0.009494 6 22.666667 1899.333333 17 0.125000 0.830882 ... 2 7 0 1 0 2 0 0 5 3
60 8 6 1.791759 0.166667 2 4.000000 8.000000 1 0.125000 0.750000 ... 1 1 1 0 0 0 0 0 1 0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
53 89 69 4.130615 0.018685 2 44.500000 2664.500000 13 0.146067 0.775281 ... 1 5 0 5 0 3 0 1 1 1
33 48 20 2.422693 0.177285 36 1.333333 4.666667 5 0.104167 0.416667 ... 1 3 0 0 0 7 1 2 0 0
61 19 17 2.833213 0.058824 2 9.500000 12.500000 4 0.210526 0.894737 ... 1 2 1 0 0 0 0 0 2 0
51 11 9 2.197225 0.111111 2 5.500000 0.500000 1 0.090909 0.818182 ... 1 2 0 0 0 0 0 0 2 0
11 9 5 1.609438 0.200000 5 1.800000 2.800000 0 0.000000 0.555556 ... 1 2 0 0 0 2 0 2 0 0
10 8 4 1.329661 0.277778 6 1.333333 0.666667 0 0.000000 0.500000 ... 0 2 2 0 0 0 0 0 2 0
54 31 28 3.319493 0.036861 2 15.500000 24.500000 5 0.161290 0.903226 ... 1 1 0 1 0 1 0 0 0 0
65 14 10 2.302585 0.100000 4 3.500000 13.000000 0 0.000000 0.714286 ... 3 2 0 0 0 0 0 1 1 0
89 13 9 2.145842 0.123967 2 6.500000 0.500000 3 0.230769 0.692308 ... 1 0 1 0 0 0 0 0 0 0
86 11 7 1.889159 0.160494 2 5.500000 0.500000 4 0.363636 0.636364 ... 2 0 0 0 0 0 0 0 0 0
21 16 10 2.302585 0.100000 10 1.600000 0.400000 3 0.187500 0.625000 ... 1 3 1 0 0 0 0 0 3 0
28 7 2 0.636514 0.555556 4 1.750000 0.250000 0 0.000000 0.285714 ... 0 1 2 0 0 0 0 0 2 0
91 56 51 3.924584 0.019970 3 18.666667 340.666667 10 0.178571 0.910714 ... 1 1 0 0 0 1 0 0 2 3
5 15 8 2.043192 0.135802 9 1.666667 0.333333 2 0.133333 0.533333 ... 0 3 0 0 0 3 0 0 0 0
50 23 18 2.857103 0.060000 2 11.500000 84.500000 1 0.043478 0.782609 ... 2 2 0 0 0 0 0 0 1 0
80 29 25 3.204778 0.041420 3 9.666667 0.333333 2 0.068966 0.862069 ... 4 3 0 0 0 1 0 1 1 0
93 13 11 2.397895 0.090909 2 6.500000 24.500000 1 0.076923 0.846154 ... 1 0 0 0 0 0 0 0 0 0
83 24 21 3.028029 0.049587 2 12.000000 18.000000 4 0.166667 0.875000 ... 2 3 0 0 0 0 0 0 1 0
71 13 11 2.397895 0.090909 2 6.500000 0.500000 2 0.153846 0.846154 ... 1 0 0 0 0 1 0 0 0 0
48 8 4 1.386294 0.250000 4 2.000000 2.000000 4 0.500000 0.500000 ... 1 0 0 0 0 0 0 0 0 0
16 33 18 2.846480 0.061224 21 1.571429 0.428571 5 0.151515 0.545455 ... 0 4 0 0 0 0 0 0 1 2
12 11 6 1.732868 0.187500 8 1.375000 0.625000 1 0.090909 0.545455 ... 2 1 1 0 0 1 0 0 0 0
15 14 10 2.253858 0.111111 10 1.400000 4.400000 0 0.000000 0.714286 ... 2 3 0 1 0 1 0 2 0 0
29 34 23 3.084652 0.048469 26 1.307692 0.692308 4 0.117647 0.676471 ... 1 9 0 3 0 0 0 0 3 1
24 14 9 2.197225 0.111111 9 1.555556 0.444444 2 0.142857 0.642857 ... 1 0 0 0 0 2 0 1 0 0
62 26 23 3.135494 0.043478 2 13.000000 98.000000 0 0.000000 0.884615 ... 3 3 0 0 0 0 0 0 3 1
88 14 11 2.369382 0.097222 2 7.000000 50.000000 2 0.142857 0.785714 ... 1 1 0 0 0 0 0 0 0 0
37 37 19 2.743635 0.086420 21 1.761905 3.269841 7 0.189189 0.513514 ... 3 2 2 0 0 2 0 0 1 2
31 24 10 2.033000 0.172840 12 2.000000 0.000000 3 0.125000 0.416667 ... 2 1 0 2 0 0 0 0 1 1
72 19 17 2.833213 0.058824 2 9.500000 112.500000 2 0.105263 0.894737 ... 1 4 0 0 0 1 0 1 0 0

67 rows × 24 columns

Standardization: so that training is not dominated by features with larger numeric scales. Note that the scaler is fit on the training set only and then applied to both splits, so no test-set statistics leak into training.

from sklearn.preprocessing  import StandardScaler
sc=StandardScaler()

sc.fit(x_train)

x_train_nor=sc.transform(x_train)
x_test_nor=sc.transform(x_test)
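A quick sanity check (my addition, not in the original post): after the transform, every training column has mean roughly 0 and standard deviation roughly 1; the one exception is a column that is constant in the training split, which StandardScaler maps to all zeros. This detail matters in the next step.

print(np.round(x_train_nor.mean(axis=0), 3))  # approximately 0 everywhere
print(np.round(x_train_nor.std(axis=0), 3))   # approximately 1, except 0 for a constant column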

Fit the classifier on the training data and print three sets of values: the coefficients, the p-values of the 24 predictors, and the intercept

from sklearn.linear_model import LogisticRegression
lr = LogisticRegression(solver='liblinear')  # make the pre-0.22 default solver explicit
lr.fit(x_train_nor, y_train.values.ravel())  # ravel() gives y the 1-D shape sklearn expects

# print the coefficients
print(lr.coef_)
# print the significance of the 24 predictors, judged by whether the p-value is below 0.05 (95% confidence level)
print(f_regression(x_train_nor, y_train.values.ravel())[1])
# print the intercept
print(lr.intercept_)

[[ 0.05073646  0.16110099  0.5169015  -0.71144527 -1.88462048  1.16090357
   0.0333939   0.24978403  0.19895375  0.952695   -0.25788532  0.26413467
  -0.34617137 -0.36122929  0.57772752 -0.00945377 -0.10280697 -0.28001234
   0.         -0.52968175 -0.06851578 -0.21853065  0.22126113 -0.07297395]]
[1.59650601e-01 6.90901540e-03 2.35938438e-05 9.17292701e-06
 4.23644597e-09 1.15729405e-08 7.59778174e-02 3.20619794e-02
 1.09218594e-01 3.39824140e-13 2.38475771e-01 1.16056518e-01
 8.96529128e-01 6.63765503e-01 7.02383648e-03 3.75252408e-01
 6.51157050e-02 6.93987446e-01            nan 3.33110705e-02
 2.40812687e-01 7.72496060e-01 8.17869677e-01 4.91588646e-01]
[0.55797716]
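Two things stand out in this output. The coefficient of exactly 0 and the nan p-value both sit at the 19th position, which is the VB column: VB is constant (all zeros) in this training split, so it carries no information and f_regression cannot compute a correlation for it (this is also the constant-column case flagged in the standardization step). Second, because the inputs were standardized, each coefficient is the change in log-odds per one standard deviation of its feature; to read the coefficients as multiplicative effects on the odds of being a human-machine dialogue (my addition, not in the original post):

odds = pd.Series(np.exp(lr.coef_[0]), index=x.columns)  # exp(beta) = odds multiplier per +1 SD
print(odds.sort_values(ascending=False))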



Compute, for each test dialogue, the probability that it is not a human-machine dialogue (first column) and the probability that it is one (second column)

np.round(lr.predict_proba(x_test_nor),3)
array([[0.004, 0.996],
       [0.927, 0.073],
       [0.967, 0.033],
       [0.96 , 0.04 ],
       [0.021, 0.979],
       [0.   , 1.   ],
       [0.9  , 0.1  ],
       [0.858, 0.142],
       [0.82 , 0.18 ],
       [0.992, 0.008],
       [0.984, 0.016],
       [0.464, 0.536],
       [0.97 , 0.03 ],
       [0.988, 0.012],
       [0.473, 0.527],
       [0.113, 0.887],
       [0.933, 0.067],
       [0.991, 0.009],
       [0.812, 0.188],
       [0.869, 0.131],
       [0.996, 0.004],
       [0.903, 0.097],
       [0.002, 0.998],
       [0.133, 0.867],
       [0.957, 0.043],
       [0.45 , 0.55 ],
       [0.203, 0.797],
       [0.002, 0.998],
       [0.992, 0.008],
       [0.85 , 0.15 ]])
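A note on how these probabilities become labels (standard scikit-learn behavior, not spelled out in the post): lr.predict applies a 0.5 cutoff to the second column, so the hard predictions used below are equivalent to thresholding these probabilities.

y_pred = lr.predict(x_test_nor)
# equivalent to: (lr.predict_proba(x_test_nor)[:, 1] > 0.5).astype(int)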

Model performance: evaluating how well the classifier does, using a confusion matrix

PS: to plot the confusion matrix, first run the code below (provided in the scikit-learn documentation)



import itertools

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
from sklearn.metrics import confusion_matrix
cnf=confusion_matrix(y_test, lr.predict(x_test_nor))
print('Confusion matrix:', cnf)
Confusion matrix: [[19  2]
 [ 0  9]]
target_name = ['no (human)', 'yes (machine)']  # class 0 = human-human, class 1 = human-machine
plot_confusion_matrix(cnf, classes=target_name, title='confusion matrix')
plt.show()
Confusion matrix, without normalization
[[19  2]
 [ 0  9]]



(Figure: confusion matrix heatmap for the test set)

# accuracy (overall classification rate)
Accuracy = (19 + 9) / (19 + 9 + 2 + 0)
print(Accuracy)
0.9333333333333333
# precision (hit rate): of the 11 dialogues predicted as machine, 9 truly were
precision = 9 / 11
print(precision)
0.8181818181818182
# recall (sensitivity, or coverage): all 9 machine dialogues were caught
recall = 9 / 9
print(recall)
1.0
# F1: the harmonic mean of precision and recall
F1 = 2 * precision * recall / (precision + recall)
print(round(F1, 3))
0.9
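These hand-computed values can be cross-checked against scikit-learn's own metric functions (my addition; the original post does the arithmetic by hand):

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
y_pred = lr.predict(x_test_nor)
print(accuracy_score(y_test, y_pred))   # 0.9333...
print(precision_score(y_test, y_pred))  # 0.8181...
print(recall_score(y_test, y_pred))     # 1.0
print(f1_score(y_test, y_pred))         # 0.9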

References
* Python機器學習(scikit-learn) – Logistic Regression
* 如何辨別機器學習模型的好壞?秒懂Confusion Matrix
* 第 22 天機器學習(2)複迴歸與 Logistic 迴歸