基于scikit-learn的文档分类:获取分类影响最大词的高效方法
获取Logistic Regression+TF-IDF分类器中影响新文档分类的关键Token
嘿,你已经用TF-IDF搭配Logistic Regression搭好了二分类器,还把模型存成了pickle用来预测新文档的分类概率,现在想找出新文档里对分类影响最大(系数最高)和最小(系数最低)的N个token对吧?刚好针对你用的sklearn v0.19版本,我整理了一套靠谱的方法,一步步来就行:
步骤1:从保存的Pipeline里拆出关键组件
你的模型是用Pipeline封装的,所以首先得把里面的TF-IDF向量器和Logistic Regression分类器取出来,这样才能拿到特征和系数:
text_model = pickle.load(open('text_model.pkl', 'rb')) # 提取TF-IDF向量器组件 tfidf_vect = text_model.named_steps['vect'] # 提取Logistic Regression分类器组件 lr_clf = text_model.named_steps['clf']
步骤2:把所有Token和对应的模型系数配对
Logistic Regression的coef_属性存着每个特征(也就是token)的权重,而TF-IDF向量器的get_feature_names()方法(sklearn 0.19版本刚好适用这个方法)能拿到所有token的列表,把它们配对成字典就方便后续查找了:
# 获取全局所有token的列表 feature_names = tfidf_vect.get_feature_names() # 二分类场景下,coef_是(1, 特征数)的数组,取第一个元素拿到一维系数数组 coef = lr_clf.coef_[0] # 把token和对应的系数绑定成字典 token_coef_map = dict(zip(feature_names, coef))
步骤3:找出新文档里实际出现的Token
我们只关心当前新文档里存在的token,所以用同一个TF-IDF向量器对新文档做转换,然后筛选出那些TF-IDF值不为0的token:
# 对新文档进行TF-IDF转换,得到稀疏矩阵 new_doc_tfidf = tfidf_vect.transform(new_document) # 获取稀疏矩阵中非零元素的列索引(对应token的位置) non_zero_indices = new_doc_tfidf.nonzero()[1] # 从全局token列表里取出新文档实际包含的token new_doc_tokens = [feature_names[idx] for idx in non_zero_indices]
步骤4:筛选排序,得到Top N关键Token
现在有了新文档的token列表和它们对应的模型系数,接下来就可以排序取最高和最低的N个了:
N = 5 # 这里可以改成你想要的数量 # 筛选出仅属于新文档的token及其系数 new_doc_token_coefs = {token: token_coef_map[token] for token in new_doc_tokens} # 按系数从高到低排序,取前N个(对分类正向影响最大的token) top_positive_tokens = sorted(new_doc_token_coefs.items(), key=lambda x: x[1], reverse=True)[:N] # 按系数从低到高排序,取前N个(对分类负向影响最大的token) top_negative_tokens = sorted(new_doc_token_coefs.items(), key=lambda x: x[1])[:N] # 输出结果 print(f"新文档中对分类影响最大的{N}个Token(系数最高):") for token, val in top_positive_tokens: print(f"{token}: {val:.4f}") print(f"\n新文档中对分类影响最大的{N}个Token(系数最低):") for token, val in top_negative_tokens: print(f"{token}: {val:.4f}")
额外说明
- 系数的正负对应着分类倾向:假设你的标签里
1代表A类,0代表B类,那么正系数的token会提升文档被分到A类的概率,负系数的token则会降低这个概率(更偏向B类)。 - 我们只筛选了新文档里实际出现的token,所以结果都是和当前文档强相关的,不是全局所有token里的极值,这样更贴合你的需求。
内容的提问来源于stack exchange,提问作者Eugenio




