# coding:utf-8
'''
March 2019 by Chen Jun
https://github.com/chenjun2hao/Attention_ocr.pytorch
'''
import torch
from torch.autograd import Variable
import utils
import dataset
from PIL import Image
from utils import alphabet
import models.crnn_lang as crnn
use_gpu = True
encoder_path = './expr/attentioncnn/encoder_5.pth'
decoder_path = './expr/attentioncnn/decoder_5.pth'
img_path = './test_img/20441531_4212871437.jpg'
max_length = 15                         # maximum length of the predicted string
EOS_TOKEN = 1

nclass = len(alphabet) + 3
encoder = crnn.CNN(32, 1, 256)          # encoder (CNN feature extractor)
# decoder = crnn.decoder(256, nclass)   # seq-to-seq decoder; another 2 is added to nclass inside the decoder
decoder = crnn.decoderV2(256, nclass)

if encoder_path and decoder_path:
    print('loading pretrained models ......')
    encoder.load_state_dict(torch.load(encoder_path))
    decoder.load_state_dict(torch.load(decoder_path))
if torch.cuda.is_available() and use_gpu:
    encoder = encoder.cuda()
    decoder = decoder.cuda()
converter = utils.strLabelConverterForAttention(alphabet)
transformer = dataset.resizeNormalize((280, 32))
image = Image.open(img_path).convert('L')
image = transformer(image)
if torch.cuda.is_available() and use_gpu:
    image = image.cuda()
image = image.view(1, *image.size())
image = Variable(image)
encoder.eval()
decoder.eval()
encoder_out = encoder(image)
decoded_words = []
prob = 1.0
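# Note: the 71 below is assumed to be the encoder's output sequence length for a
# 280-pixel-wide input; if the input width changes, this size would have to match it.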
decoder_attentions = torch.zeros(max_length, 71)
decoder_input = torch.zeros(1).long()   # initialize the decoder input with the start index 0
decoder_hidden = decoder.initHidden(1)
if torch.cuda.is_available() and use_gpu:
    decoder_input = decoder_input.cuda()
    decoder_hidden = decoder_hidden.cuda()
loss = 0.0
# At inference time decoding is not teacher-forced: the previous prediction is fed
# back as the next input, and decoding stops once the EOS_TOKEN label is produced.
for di in range(max_length):            # iterate up to the maximum string length
    decoder_output, decoder_hidden, decoder_attention = decoder(
        decoder_input, decoder_hidden, encoder_out)
    probs = torch.exp(decoder_output)   # convert the decoder's log-probabilities to probabilities
    decoder_attentions[di] = decoder_attention.data
    topv, topi = decoder_output.data.topk(1)
    ni = topi.squeeze(1)
    decoder_input = ni
    prob *= probs[:, ni]                # accumulate the probability of the greedy path
    if ni == EOS_TOKEN:
        # decoded_words.append('<EOS>')
        break
    else:
        decoded_words.append(converter.decode(ni))
words = ''.join(decoded_words)
prob = prob.item()
print('predict_str:%-20s => prob:%-20s' % (words, prob))