-
Notifications
You must be signed in to change notification settings - Fork 185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PPSCI Export&Infer No. 29】 add export and inference #793
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,8 +14,13 @@ hydra: | |
- EVAL.pretrained_model_path_dict | ||
- EVAL.batch_size | ||
- EVAL.num_val_step | ||
- EXPORT.pretrained_model_name | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. EXPORT->INFER There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
- INFER.pretrained_model_path_dict | ||
- INFER.export_path | ||
- INFER.batch_size | ||
- mode | ||
- vol_coeff | ||
- log_freq | ||
sweep: | ||
# output directory for multirun | ||
dir: ${hydra.run.dir} | ||
|
@@ -25,6 +30,7 @@ hydra: | |
mode: train # running mode: train/eval | ||
seed: 42 | ||
output_dir: ${hydra:run.dir} | ||
log_freq: 20 | ||
|
||
# set default cases parameters | ||
CASE_PARAM: [[Poisson, 5], [Poisson, 10], [Poisson, 30], [Uniform, null]] | ||
|
@@ -57,3 +63,28 @@ EVAL: | |
pretrained_model_path_dict: null # a dict: {casename1:path1, casename2:path2, casename3:path3, casename4:path4} | ||
num_val_step: 10 # the number of iteration for each evaluation case | ||
batch_size: 16 | ||
|
||
# inference settings | ||
INFER: | ||
pretrained_model_name: null # a string, indicating which model you want to export. Support [Uniform, Poisson5, Poisson10, Poisson30]. | ||
pretrained_model_path_dict: {'Uniform': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/uniform_pretrained.pdparams', 'Poisson5': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson5_pretrained.pdparams', 'Poisson10': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson10_pretrained.pdparams', 'Poisson30': 'https://paddle-org.bj.bcebos.com/paddlescience/models/topopt/poisson30_pretrained.pdparams'} | ||
export_path: ./inference/topopt | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如果在pretrained_model_path_dict里的模型是互相独立的,可以把export_path改为 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
pdmodel_path: ${INFER.export_path}.pdmodel | ||
pdpiparams_path: ${INFER.export_path}.pdiparams | ||
device: gpu | ||
engine: native | ||
precision: fp32 | ||
onnx_path: null | ||
ir_optim: true | ||
min_subgraph_size: 30 | ||
gpu_mem: 4000 | ||
gpu_id: 0 | ||
max_batch_size: 1024 | ||
num_cpu_threads: 10 | ||
batch_size: 4 | ||
sampler_key: Fixed # a string, indicating the sampling method. Support [Fixed, Uniform, Poisson]. | ||
sampler_num: 8 # a integer number, indicating the sampling rate of the sampling method, supported when `sampler_key` is Fixed or Poisson. | ||
img_num: 4 | ||
res_img_figsize: null | ||
save_res_path: ./inference/predicted | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 此处的save_res_path也可以改为 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
save_npy: false |
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -17,13 +17,16 @@ | |||||||||
import functions as func_module | ||||||||||
import h5py | ||||||||||
import hydra | ||||||||||
import matplotlib.pyplot as plt | ||||||||||
import numpy as np | ||||||||||
import paddle | ||||||||||
from omegaconf import DictConfig | ||||||||||
from paddle import nn | ||||||||||
from paddle.static import InputSpec | ||||||||||
from topoptmodel import TopOptNN | ||||||||||
|
||||||||||
import ppsci | ||||||||||
from deploy.python_infer import pinn_predictor | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这几行import建议放在对应函数内部,否则会影响整体topopt的代码块行数定位 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||||||||||
from ppsci.utils import logger | ||||||||||
|
||||||||||
|
||||||||||
|
@@ -120,7 +123,7 @@ def evaluate(cfg: DictConfig): | |||||||||
|
||||||||||
# fixed iteration stop times for evaluation | ||||||||||
iterations_stop_times = range(5, 85, 5) | ||||||||||
model = TopOptNN() | ||||||||||
model = TopOptNN(**cfg.MODEL) | ||||||||||
|
||||||||||
# evaluation for 4 cases | ||||||||||
acc_results_summary = {} | ||||||||||
|
@@ -317,14 +320,120 @@ def val_metric(output_dict, label_dict, weight_dict=None): | |||||||||
return {"Binary_Acc": acc, "IoU": iou} | ||||||||||
|
||||||||||
|
||||||||||
# export model | ||||||||||
def export(cfg: DictConfig): | ||||||||||
# set model | ||||||||||
model = TopOptNN(**cfg.MODEL) | ||||||||||
|
||||||||||
# initialize solver | ||||||||||
solver = ppsci.solver.Solver( | ||||||||||
model, | ||||||||||
eval_with_no_grad=True, | ||||||||||
pretrained_model_path=cfg.INFER.pretrained_model_path_dict[ | ||||||||||
cfg.EXPORT.pretrained_model_name | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. EXPORT->INFER There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||||||||||
], | ||||||||||
) | ||||||||||
|
||||||||||
# export model | ||||||||||
input_spec = [{"input": InputSpec([None, 2, 40, 40], "float32", name="input")}] | ||||||||||
|
||||||||||
solver.export(input_spec, cfg.INFER.export_path) | ||||||||||
|
||||||||||
|
||||||||||
# model inference | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这行注释可以删除 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已删除 |
||||||||||
def inference(cfg: DictConfig): | ||||||||||
# read h5 data | ||||||||||
h5data = h5py.File(cfg.DATA_PATH, "r") | ||||||||||
data_iters = np.array(h5data["iters"]) | ||||||||||
data_targets = np.array(h5data["targets"]) | ||||||||||
idx = np.random.choice(len(data_iters), cfg.INFER.img_num, False) | ||||||||||
data_iters = data_iters[idx] | ||||||||||
data_targets = data_targets[idx] | ||||||||||
|
||||||||||
sampler = func_module.generate_sampler(cfg.INFER.sampler_key, cfg.INFER.sampler_num) | ||||||||||
data_iters = channel_sampling(sampler, data_iters) | ||||||||||
|
||||||||||
predictor = pinn_predictor.PINNPredictor(cfg) | ||||||||||
|
||||||||||
input_dict = {"input": data_iters} | ||||||||||
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size) | ||||||||||
|
||||||||||
# mapping data to output_key | ||||||||||
output_dict = { | ||||||||||
store_key: output_dict[infer_key] | ||||||||||
for store_key, infer_key in zip({"output"}, output_dict.keys()) | ||||||||||
} | ||||||||||
|
||||||||||
save_topopt_img( | ||||||||||
input_dict, | ||||||||||
output_dict, | ||||||||||
data_iters, | ||||||||||
cfg.INFER.save_res_path, | ||||||||||
cfg.INFER.res_img_figsize, | ||||||||||
cfg.INFER.save_npy, | ||||||||||
) | ||||||||||
|
||||||||||
|
||||||||||
# used for inference | ||||||||||
def channel_sampling(sampler, input): | ||||||||||
SIMP_initial_iter_time = sampler() | ||||||||||
input_channel_k = input[:, SIMP_initial_iter_time, :, :] | ||||||||||
input_channel_k_minus_1 = input[:, SIMP_initial_iter_time - 1, :, :] | ||||||||||
input = np.stack( | ||||||||||
(input_channel_k, input_channel_k - input_channel_k_minus_1), axis=1 | ||||||||||
) | ||||||||||
return input | ||||||||||
|
||||||||||
|
||||||||||
# used for inference | ||||||||||
def save_topopt_img( | ||||||||||
input_dict, output_dict, ground_truth, res_path, figsize=None, npy=False | ||||||||||
): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||||||||||
|
||||||||||
input = input_dict["input"] | ||||||||||
output = output_dict["output"] | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 保存前先创建文件夹,否则会报错
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||||||||||
for i in range(len(input)): | ||||||||||
plt.figure(figsize=figsize) | ||||||||||
plt.subplot(1, 4, 1) | ||||||||||
plt.axis("off") | ||||||||||
plt.imshow(input[i][0], cmap="gray") | ||||||||||
plt.title("Input Image") | ||||||||||
plt.subplot(1, 4, 2) | ||||||||||
plt.axis("off") | ||||||||||
plt.imshow(input[i][1], cmap="gray") | ||||||||||
plt.title("Input Gradient") | ||||||||||
plt.subplot(1, 4, 3) | ||||||||||
plt.axis("off") | ||||||||||
plt.imshow(np.round(output[i][0]), cmap="gray") | ||||||||||
print(output[i]) | ||||||||||
plt.title("Prediction") | ||||||||||
plt.subplot(1, 4, 4) | ||||||||||
plt.axis("off") | ||||||||||
plt.imshow(np.round(ground_truth[i][0]), cmap="gray") | ||||||||||
print(ground_truth[i]) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 两处print可以删除,否则会打印大量矩阵元素,如果是为了友好提示可以在for循环的最后一行使用logger.message信息打印保存成功的提示 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 忘记删掉了(✿◡‿◡),已修改 |
||||||||||
plt.title("Ground Truth") | ||||||||||
plt.show() | ||||||||||
plt.savefig(osp.join(res_path, "Prediction_" + str(i) + ".png")) | ||||||||||
plt.close() | ||||||||||
if npy: | ||||||||||
with open(osp(res_path, "Prediction_" + str(i) + ".npy"), "wb") as f: | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 两处字符串拼接改为f-string: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||||||||||
np.save(f, output[i]) | ||||||||||
|
||||||||||
|
||||||||||
@hydra.main(version_base=None, config_path="./conf", config_name="topopt.yaml") | ||||||||||
def main(cfg: DictConfig): | ||||||||||
if cfg.mode == "train": | ||||||||||
train(cfg) | ||||||||||
elif cfg.mode == "eval": | ||||||||||
evaluate(cfg) | ||||||||||
elif cfg.mode == "export": | ||||||||||
export(cfg) | ||||||||||
elif cfg.mode == "infer": | ||||||||||
inference(cfg) | ||||||||||
else: | ||||||||||
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") | ||||||||||
raise ValueError( | ||||||||||
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" | ||||||||||
) | ||||||||||
|
||||||||||
|
||||||||||
if __name__ == "__main__": | ||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是否可以以注释的形式补充另外几个Poisson系列的模型呢?模型推理命令同
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改