Source code for topsbm.transformer
"""
This contains the hSBM topic modelling transformer, TopSBM
"""
# This file is part of TopSBM
# Copyright 2017-8, Martin Gerlach and the University of Sydney
#
# TopSBM is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TopSBM is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with TopSBM. If not, see <https://www.gnu.org/licenses/>.
import numpy as np
import graph_tool
from graph_tool import Graph
from graph_tool.inference import minimize_nested_blockmodel_dl
import scipy.sparse
from sklearn.base import BaseEstimator
from sklearn.utils import check_array, check_random_state
[docs]class TopSBM(BaseEstimator):
"""A Scikit-learn compatible transformer for hSBM topic models
Parameters
----------
n_init : int, default=1
Number of random initialisations to perform in order to avoid a local
minimum of MDL. The minimum MDL solution is chosen.
min_groups : int, default=None
The minimum number of word and document groups to infer. This is also a
lower bound on the number of topics.
max_groups : int, default=None
The maximum number of word and document groups to infer. This is also an
upper bound on the number of topics.
weighted_edges : bool, default=True
When True, edges are weighted instead of adding duplicate edges.
random_state : None, int or np.random.RandomState
Controls randomization. See Scikit-learn's glossary.
Note that if this is set, the global random state of libcore will be
affected, and the global random state of numpy will be temporarily
affected.
Attributes
----------
graph_ : graph_tool.Graph
Bipartite graph between samples (the first `n_samples_` vertices) and
features (the remaining vertices)
state_
Inference state from graph-tool
n_levels_ : int
The number of levels in the inferred hierarchy of groups.
groups_ : dict
Results of group membership from inference.
Key is an integer, indicating the level of grouping (starting from 0).
Value is a dict of information about the grouping which contains:
Bd : int
number of doc-groups
Bw : int
number of word-groups
p_tw_d : array of shape (Bw, n_samples)
doc-topic mixtures:
prob of word-group tw in doc d P(tw | d)
p_td_d : array of shape (Bd, n_samples)
doc-group membership:
prob that doc-node d belongs to doc-group td: P(td | d)
p_tw_w : array of shape (Bw, n_features)
word-group-membership:
prob that word-node w belongs to word-group tw: P(tw | w)
p_w_tw : array of shape (n_features, Bw)
topic distribution:
prob of word w given topic tw P(w | tw)
Here "d"/document refers to samples; "w"/word refers to features.
mdl_
minimum description length of inferred state
n_features_ : int
n_samples_ : int
References
----------
Martin Gerlach, Tiago P. Peixoto, and Eduardo G. Altmann,
`“A network approach to topic models,”
<http://advances.sciencemag.org/content/4/7/eaaq1360>`_.
Science Advances (2018)
"""
def __init__(self, n_init=1, min_groups=None, max_groups=None,
             weighted_edges=True, random_state=None):
    """Record the constructor arguments as same-named attributes.

    Nothing is validated or computed here; each argument is stored
    unchanged.

    Parameters
    ----------
    n_init : int, default=1
        Number of inference runs performed by ``__fit_hsbm``; the run
        with the smallest description length is kept.
    min_groups : int, default=None
        Forwarded to the inference routine as ``B_min``.
    max_groups : int, default=None
        Forwarded to the inference routine as ``B_max``.
    weighted_edges : bool, default=True
        When True, ``__make_graph`` stores counts in an edge property
        instead of adding one edge per token.
    random_state : None, int or np.random.RandomState
        Stored as given.
        NOTE(review): it is not consumed by any method visible in this
        part of the file -- presumably ``fit_transform`` uses it; confirm.
    """
    self.n_init = n_init
    self.min_groups = min_groups
    self.max_groups = max_groups
    self.weighted_edges = weighted_edges
    self.random_state = random_state
def __make_graph(self, X):
    """Build the undirected document-word graph that represents ``X``.

    One vertex is created per row of ``X`` (documents) followed by one
    vertex per column (words). Every stored entry ``X[row, col]`` becomes
    either a single edge carrying the entry in the ``count`` edge
    property (``weighted_edges=True``) or ``count`` parallel edges
    (``weighted_edges=False``).

    Parameters
    ----------
    X : array-like or sparse matrix of shape (n_samples, n_features)
        Converted internally with ``scipy.sparse.coo_matrix``.

    Returns
    -------
    g : graph_tool.Graph
        Graph with vertex property ``kind`` (0 for documents, 1 for
        words) and, when ``weighted_edges`` is True, edge property
        ``count``.

    Raises
    ------
    ValueError
        If ``weighted_edges`` is False and ``X`` holds values that are
        not close to integers.
    """
    g = Graph(directed=False)
    # vertex property "kind": docs - 0, words - 1
    kind = g.vp["kind"] = g.new_vp("int")
    if self.weighted_edges:
        ecount = g.ep["count"] = g.new_ep("int")

    # add all documents first, then all words
    doc_vertices = [g.add_vertex() for _ in range(X.shape[0])]
    word_vertices = [g.add_vertex() for _ in range(X.shape[1])]

    # Label every vertex up front rather than while adding edges: a
    # vertex whose row/column has no stored entry is never visited by
    # the edge loop below and would otherwise never receive its label.
    for doc_vert in doc_vertices:
        kind[doc_vert] = 0
    for word_vert in word_vertices:
        kind[word_vert] = 1

    X = scipy.sparse.coo_matrix(X)
    if not self.weighted_edges and X.dtype != int:
        # parallel edges can only represent whole-number counts
        X_int = X.astype(int)
        if not np.allclose(X.data, X_int.data):
            raise ValueError('Data must be integer if '
                             'weighted_edges=False')
        X = X_int

    # add all tokens as links
    for row, col, count in zip(X.row, X.col, X.data):
        doc_vert = doc_vertices[row]
        word_vert = word_vertices[col]
        if self.weighted_edges:
            e = g.add_edge(doc_vert, word_vert)
            ecount[e] = count
        else:
            for _ in range(count):
                g.add_edge(doc_vert, word_vert)
    return g
def __fit_hsbm(self):
    """Run nested block-model inference on ``self.graph_``.

    Inference is repeated ``n_init`` times and the state with the
    smallest entropy (description length) is kept.

    Sets
    ----
    state_ : copy of the best inferred state
    mdl_ : entropy of that state
    groups_ : dict mapping level -> result of ``__get_groups(level)``
    n_levels_ : number of entries in ``groups_``

    Raises
    ------
    ValueError
        If ``n_init`` is smaller than 1. Without this guard the loop
        would not run and ``state_`` would be missing (or left over
        from a previous fit).
    """
    if self.n_init < 1:
        raise ValueError('n_init must be at least 1; got %r'
                         % (self.n_init,))

    # constrain the partition so that groups never mix vertex kinds
    clabel = self.graph_.vp['kind']
    state_args = {'clabel': clabel, 'pclabel': clabel}
    if "count" in self.graph_.ep:
        # edges carry counts rather than being duplicated
        state_args["eweight"] = self.graph_.ep.count

    self.mdl_ = np.inf
    for _ in range(self.n_init):
        # the inference
        # TODO: implement overlap
        state = minimize_nested_blockmodel_dl(self.graph_, deg_corr=True,
                                              state_args=state_args,
                                              B_min=self.min_groups,
                                              B_max=self.max_groups)
        mdl = state.entropy()
        if mdl < self.mdl_:
            # keep the best run seen so far
            self.state_ = state.copy()
            self.mdl_ = mdl
        del state

    # collect group membership for each level in the hierarchy
    n_levels = len(self.state_.levels)
    if n_levels == 2:
        # only trivial bipartite structure
        self.groups_ = {0: self.__get_groups(level=0)}
    else:
        # omit trivial levels:
        # - level=n_levels-1 (single group),
        # - level=n_levels-2 (bipartite)
        self.groups_ = {level: self.__get_groups(level=level)
                        for level in range(n_levels - 2)}
    self.n_levels_ = len(self.groups_)
def fit(self, X, y=None):
    """Fit the hSBM topic model to ``X``.

    The work is delegated to ``fit_transform``; its return value is
    discarded and the estimator itself is handed back.

    Parameters
    ----------
    X : ndarray or sparse matrix of shape (n_samples, n_features)
        Word frequencies for each document, represented as non-negative
        integers.
    y : ignored

    Returns
    -------
    self
    """
    _ = self.fit_transform(X)
    return self
def __get_groups(self, level=0):
    """Extract group-membership statistics for one hierarchy level.

    The state at ``level`` is projected and copied with ``overlap=True``
    so that each edge carries a pair of group labels (one per end).
    These labelled half-edges are tallied per document and per word,
    each edge contributing its ``count`` edge property when the graph
    has one and 1 otherwise, and the tallies are normalised into
    distributions.

    Parameters
    ----------
    level : int, default=0
        Level of the hierarchy passed to ``state_.project_level``.

    Returns
    -------
    dict with keys
        Bd : int, number of doc-groups with at least one half-edge
        Bw : int, number of word-groups with at least one half-edge
        p_tw_w : array of shape (Bw, n_features), P(tw | w)
        p_td_d : array of shape (Bd, n_samples), P(td | d)
        p_w_tw : array of shape (n_features, Bw), P(w | tw)
        p_tw_d : array of shape (Bw, n_samples), P(tw | d)

    Notes
    -----
    A document or word without any edge has an all-zero tally, so its
    normalised entries are NaN (0 / 0).
    """
    level_state = self.state_.project_level(level).copy(overlap=True)
    level_state_edges = level_state.get_edge_blocks()

    # count labeled half-edges, group-memberships
    n_groups = level_state.B
    n_wb = np.zeros((self.n_features_, n_groups))
    n_db = np.zeros((self.n_samples_, n_groups))
    n_dbw = np.zeros((self.n_samples_, n_groups))

    # With weighted edges a single edge stands for `count` tokens, so the
    # tallies must use that multiplicity rather than 1 per edge.
    ecount = self.graph_.ep["count"] if "count" in self.graph_.ep else None

    for e in self.graph_.edges():
        z1, z2 = level_state_edges[e]
        # edges are created as (document, word), and word vertices are
        # numbered after the n_samples_ document vertices
        v1 = int(e.source())
        v2 = int(e.target())
        weight = 1 if ecount is None else ecount[e]
        n_db[v1, z1] += weight
        n_dbw[v1, z2] += weight
        n_wb[v2 - self.n_samples_, z2] += weight

    # drop the groups (columns) that received no half-edge at all
    n_db = n_db[:, np.any(n_db, axis=0)]
    Bd = n_db.shape[1]
    n_wb = n_wb[:, np.any(n_wb, axis=0)]
    Bw = n_wb.shape[1]
    n_dbw = n_dbw[:, np.any(n_dbw, axis=0)]

    # group membership of each word-node P(t_w | w)
    p_tw_w = (n_wb / np.sum(n_wb, axis=1)[:, np.newaxis]).T
    # group membership of each doc-node P(t_d | d)
    p_td_d = (n_db / np.sum(n_db, axis=1)[:, np.newaxis]).T
    # topic-distribution for words P(w | t_w)
    p_w_tw = n_wb / np.sum(n_wb, axis=0)[np.newaxis, :]
    # mixture of word-groups into documents P(t_w | d)
    p_tw_d = (n_dbw / np.sum(n_dbw, axis=1)[:, np.newaxis]).T

    result = {}
    result['Bd'] = Bd
    result['Bw'] = Bw
    result['p_tw_w'] = p_tw_w
    result['p_td_d'] = p_td_d
    result['p_w_tw'] = p_w_tw
    result['p_tw_d'] = p_tw_d
    return result
def plot_graph(self, filename=None, n_edges=1000):
    """Draw the fitted hierarchy as arcs between documents and words.

    The call is forwarded to ``state_.draw`` with a bipartite layout.
    This method does not return the result of that call.

    Parameters
    ----------
    filename : str, optional
        Path to write to (e.g. 'something.png'); passed on as ``output``.
        When omitted, ``output=None`` is passed on.
        NOTE(review): how the drawing is shown in that case is decided by
        graph-tool -- confirm against its documentation.
    n_edges : int
        Size of the edge subsample to plot (reducing memory
        requirements); passed on as ``subsample_edges``.
    """
    draw_options = {
        'layout': 'bipartite',
        'output': filename,
        'subsample_edges': n_edges,
        'hshortcuts': 1,
        'hide': 0,
    }
    self.state_.draw(**draw_options)