# Source code for ggfm.data.metapath
# import dgl
import torch
import random
from .utils import save_pkl_file, open_pkl_file
# from dgl.data.utils import save_graphs, load_graphs
def construct_graph(data_dir, graph, src_dst2edge_type):
    r"""
    Flatten the edges of a graph into per-edge-type ``(src, dst)`` tensors.

    ``graph.edge_list`` is read as a nested mapping
    ``[target_type][source_type][relation_type][target_id] -> iterable of source ids``.
    Every relation is renamed to a canonical edge type
    ``(src_type, edge_name, dst_type)`` and relations that share a canonical
    edge type are merged. The returned dict has the
    ``{(src_type, edge_name, dst_type): (src_ids, dst_ids)}`` layout that
    ``dgl.heterograph`` takes as input; this function itself neither builds
    nor saves a DGL graph.

    Parameters
    ----------
    data_dir: str
        Not used by this function; kept so existing callers keep working.
    graph: class:`ggfm.data.Graph`
        Target graph. Only its ``edge_list`` attribute is read.
    src_dst2edge_type: dict
        The edge types corresponding to (src, dst) types. It is not consulted
        for paper-paper and field-field relations, whose names are derived
        from the relation name instead (see below).

    Returns
    -------
    dict
        Maps ``(src_type, edge_name, dst_type)`` to a pair of 1-D ``int64``
        tensors ``(src_ids, dst_ids)`` of equal length.

    Raises
    ------
    KeyError
        If a ``(src_type, dst_type)`` pair other than paper-paper or
        field-field is missing from ``src_dst2edge_type``.
    """
    # Relations between nodes of the same type cannot be told apart by their
    # (src, dst) types, so the direction is taken from the relation name:
    # (name for a plain relation, name for a relation containing "rev").
    self_relation_names = {
        "paper": ("cites", "cited by"),
        "field": ("in", "contains"),
    }

    edge_lists = {}
    for target_type, by_source in graph.edge_list.items():
        for source_type, by_relation in by_source.items():
            for relation_type, adjacency in by_relation.items():
                if source_type == target_type and source_type in self_relation_names:
                    forward_name, reverse_name = self_relation_names[source_type]
                    edge_name = reverse_name if "rev" in relation_type else forward_name
                else:
                    edge_name = src_dst2edge_type[(source_type, target_type)]

                # setdefault also registers relations that have no edges, so
                # every relation seen in the graph shows up in the result.
                srcs, dsts = edge_lists.setdefault(
                    (source_type, edge_name, target_type), ([], [])
                )
                for target_id, source_ids in adjacency.items():
                    for source_id in source_ids:
                        srcs.append(source_id)
                        dsts.append(target_id)

    # An explicit dtype keeps empty edge lists as int64; torch.tensor([])
    # would otherwise fall back to the default floating-point dtype.
    return {
        edge_type: (
            torch.tensor(srcs, dtype=torch.int64),
            torch.tensor(dsts, dtype=torch.int64),
        )
        for edge_type, (srcs, dsts) in edge_lists.items()
    }
def construct_graph_node_name(data_dir, graph):
    r"""
    Build the per-type list of node names and store it as graph_node_name.pkl.

    Parameters
    ----------
    data_dir: str
        Prefix of the output path; the result is passed to ``save_pkl_file``
        with the path ``data_dir + "graph_node_name.pkl"``, so it is expected
        to end with a path separator.
    graph: class:`ggfm.data.Graph`
        Target graph. Its node types come from ``graph.get_types()`` and the
        names from ``graph.node_feature``.
    """
    names_by_type = {}
    for node_type in graph.get_types():
        # Papers are labelled by their title, every other node type by its name.
        name_column = "title" if node_type == "paper" else "name"
        # NOTE(review): the column must support ``.tolist()`` (presumably a
        # pandas Series) -- confirm against ggfm.data.Graph.
        names_by_type[node_type] = graph.node_feature[node_type][name_column].tolist()

    save_pkl_file(data_dir + "graph_node_name.pkl", names_by_type)
def _walk_to_sentence(trace, target_type, graph_node_name, path_relations, path_mid_types):
    """Render one sampled walk as a space-separated sentence ending in ``</s>``."""
    tokens = [
        target_type.replace('_', ' '),
        graph_node_name[target_type][trace[0]].replace('_', ' '),
    ]
    # A walk that stops early is padded with -1. Such a walk is reduced to its
    # start node: indexing the name list with -1 would silently pick the name
    # of the last node instead of failing.
    if len(trace) > 1 and -1 not in trace[1:]:
        last_hop = len(trace) - 1
        for hop in range(1, last_hop + 1):
            # The final node of a metapath is taken to be of the target type;
            # the types of the nodes in between come from ``path_mid_types``.
            node_type = target_type if hop == last_hop else path_mid_types[hop - 1]
            tokens.append(path_relations[hop - 1])
            tokens.append(node_type.replace('_', ' '))
            tokens.append(graph_node_name[node_type][trace[hop]].replace('_', ' '))
    tokens.append("</s>")
    return " ".join(tokens)


def metapath_based_corpus_construction(g, graph_node_name, target_type, metapaths, relation, mid_types, labeled_node_idxs, k=2, sampling_time=10):
    r"""
    Metapath-based corpus construction.

    For every metapath of ``target_type``, ``sampling_time`` random walks are
    started from each node in ``labeled_node_idxs``. Each walk is rendered as
    ``"<type> <name> <relation> <type> <name> ... </s>"`` (underscores in
    type and node names are replaced by spaces). Per node and metapath the
    distinct sentences are collected and at most ``k`` of them are drawn with
    ``random.sample``; the selections of all metapaths are then de-duplicated
    and joined with single spaces into one string per node.

    Parameters
    ----------
    g: dgl.heterograph
        Target graph.
    graph_node_name: dict
        Node names for each node type.
    target_type: str
        Sampled node type.
    metapaths: list
        Sampled metapaths for each node type (indexed by ``target_type``).
    relation: list
        Relations for each node type (indexed by ``target_type``); entry
        ``p`` holds one relation string per hop of the ``p``-th metapath.
    mid_types: list
        Midtypes of each node type's metapaths, i.e. the types of the nodes
        between the start node and the final node of every metapath.
    labeled_node_idxs: list
        Sampled node indexes for the target node type.
    k: int, optional
        Number of samples retained for each metapath sampling.
        (default: :obj:`2`)
    sampling_time: int, optional
        Number of random walks started from every node for each metapath.
        (default: :obj:`10`)

    Returns
    -------
    list
        One corpus string per entry of ``labeled_node_idxs``, in the same
        order. Every sentence, including the last one of a string, ends
        with ``</s>``. The list is empty when there are no metapaths, no
        start nodes, or ``sampling_time`` is 0.
    """
    path_specs = metapaths[target_type]
    relations_by_path = relation[target_type]
    num_nodes = len(labeled_node_idxs)

    print("---------------------------------------")
    print(f"Sampling {target_type} type nodes...")

    # Nothing to sample: return early instead of failing further down.
    if not path_specs or num_nodes == 0:
        print("length: 0")
        return []

    # Imported here because the module-level ``import dgl`` is disabled;
    # DGL is only needed once walks are actually sampled.
    import dgl

    start_nodes = torch.tensor(labeled_node_idxs)
    selections_per_node = [[] for _ in range(num_nodes)]

    for p, path in enumerate(path_specs):
        path_relations = relations_by_path[p]
        path_mid_types = mid_types[target_type][p]
        sentences_per_node = [[] for _ in range(num_nodes)]
        print(f"Sampling the {p}-th path...")
        for st in range(sampling_time):
            traces, _ = dgl.sampling.random_walk(g=g, nodes=start_nodes, metapath=path)
            traces = traces.tolist()
            print(f"Performing the {st}-th sampling...")
            for i, trace in enumerate(traces):
                if i % 10000 == 0:
                    print(f"Sampled {i} nodes...")
                sentences_per_node[i].append(
                    _walk_to_sentence(
                        trace, target_type, graph_node_name, path_relations, path_mid_types
                    )
                )

        for i, sentences in enumerate(sentences_per_node):
            if not sentences:
                continue
            # dict.fromkeys removes duplicates while keeping first-seen order,
            # so the outcome does not depend on string hash order.
            distinct = list(dict.fromkeys(sentences))
            chosen = random.sample(distinct, min(k, len(distinct)))
            selections_per_node[i].append(" ".join(chosen))

    # NOTE(review): the previous code called ``.rstrip(' </s> ')`` on each
    # string here but discarded the result, so the trailing "</s>" was always
    # kept. That output is preserved; confirm whether it should be stripped.
    corpus = [
        " ".join(dict.fromkeys(selections))
        for selections in selections_per_node
        if selections
    ]

    if corpus:
        print(corpus[0])
    print(f"length: {len(corpus)}")
    return corpus