双任务网络改写感触纪要:
1.论志同道合者的重要性。宝剑赠英雄,红粉送佳人。生命中三两同行人,足矣。
2.小人物,笨办法,大智慧。 照着葫芦画瓢。别人的网络可以运行的通,就仿照他们的编写格式,进行修改,适应自己的个性化任务。修改之后,要是报错,不要气馁。实在不行的情况下,找到报错关键地方,一句一句的采用调试方式或print(),接下来要的就是耐心和细心了,日拱一步,步步为营。
3.几多愁来几多喜!
4.感叹,人世唯艰,人间正道是沧桑。
# class_num = 31 # model m_models1 = {tmp:model.network_dict['AlexNetFc'](True) for tmp in dataset_train} ''' m_models = {tmp:nn.Sequential(m_models[tmp], nn.Linear(m_models[tmp].out_features, 3)) for tmp in dataset_train} m_models2 = {tmp:nn.Sequential(m_models[tmp], nn.Linear(m_models[tmp].out_features, 5)) for tmp in dataset_train if tmp == "target"} ''' m_models = {'source':nn.Sequential(m_models1['source'], nn.Linear(m_models1['source'].out_features, 2))} print(m_models) m_models2 = {'target':nn.Sequential(m_models1['target'], nn.Linear(m_models1['target'].out_features, 5))} print(m_models2) m_models.update(m_models2) print("dataset_train:", dataset_train) print( m_models) ''' for tmp in dataset_train: print(tmp) print(m_models[tmp]) print({tmp:m_models[tmp]}) m_models = {tmp:nn.Sequential(m_models[tmp['souce']], nn.Linear(m_models[tmp['souce']].out_features, 5))} m_models2 = {tmp:nn.Sequential(m_models[tmp], nn.Linear(m_models[tmp].out_features, 5))} print("tmp", tmp) m_models = {tmp:nn.Sequential(m_models[tmp], nn.Linear(m_models[tmp].out_features, 5))} if tmp == "source": m_models = {tmp:nn.Sequential(m_models[tmp], nn.Linear(m_models[tmp].out_features, 5))} print('tmp', nn.Sequential(m_models[tmp], nn.Linear(m_models[tmp].out_features, 2))) print("m_models:", m_models) else: print(dataset_train[tmp]) m_models2 = {tmp:nn.Sequential(m_models[tmp], nn.Linear(m_models[tmp].out_features, 5))} print(t.get('d')) m_models.update(m_models2) ''' # load pretrained parameters # for tmp in m_models: # m_models[tmp].load_state_dict(torch.load(os.path.join(module_path[tmp], 'best_model.pth'))) print('m_models:',m_models) print('m_models[source]:',m_models['source']) print('m_models[source][0]',m_models['source'][0]) base_cross_stitch_net = cross_stitch_network.CrossStitchNetwork(m_models['source'][0], m_models['target'][0]) print('base_cross_stitch_net:',base_cross_stitch_net)