Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions captcha_cnn_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# -*- coding: UTF-8 -*-
import torch.nn as nn
import captcha_setting
import torchvision.models as models
import torchvision

# CNN Model (2 conv layer)
class CNN(nn.Module):
Expand Down Expand Up @@ -41,3 +43,134 @@ def forward(self, x):
out = self.rfc(out)
return out



class RES18(nn.Module):
def __init__(self):
super(RES18, self).__init__()
self.num_cls = captcha_setting.MAX_CAPTCHA*captcha_setting.ALL_CHAR_SET_LEN
self.base = torchvision.models.resnet18(pretrained=False)
self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
def forward(self, x):
out = self.base(x)
return out

class RES34(nn.Module):
def __init__(self):
super(RES34, self).__init__()
self.num_cls = captcha_setting.MAX_CAPTCHA*captcha_setting.ALL_CHAR_SET_LEN
self.base = torchvision.models.resnet34(pretrained=False)
self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
def forward(self, x):
out = self.base(x)
return out

class RES50(nn.Module):
def __init__(self):
super(RES50, self).__init__()
self.num_cls = captcha_setting.MAX_CAPTCHA*captcha_setting.ALL_CHAR_SET_LEN
self.base = torchvision.models.resnet50(pretrained=False)
self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
def forward(self, x):
out = self.base(x)
return out

class RES101(nn.Module):
def __init__(self):
super(RES101, self).__init__()
self.num_cls = captcha_setting.MAX_CAPTCHA*captcha_setting.ALL_CHAR_SET_LEN
self.base = torchvision.models.resnet101(pretrained=False)
self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
def forward(self, x):
out = self.base(x)
return out

class RES152(nn.Module):
def __init__(self):
super(RES152, self).__init__()
self.num_cls = captcha_setting.MAX_CAPTCHA*captcha_setting.ALL_CHAR_SET_LEN
self.base = torchvision.models.resnet152(pretrained=False)
self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
def forward(self, x):
out = self.base(x)
return out

class ALEXNET(nn.Module):
def __init__(self):
super(ALEXNET, self).__init__()
self.num_cls = captcha_setting.MAX_CAPTCHA*captcha_setting.ALL_CHAR_SET_LEN
self.base = torchvision.models.alexnet(pretrained=False)
self.base.classifier[-1] = nn.Linear(4096, self.num_cls)
def forward(self, x):
out = self.base(x)
return out

class VGG11(nn.Module):
def __init__(self):
super(VGG11, self).__init__()
self.num_cls = captcha_setting.MAX_CAPTCHA*captcha_setting.ALL_CHAR_SET_LEN
self.base = torchvision.models.vgg11(pretrained=False)
self.base.classifier[-1] = nn.Linear(4096, self.num_cls)
def forward(self, x):
out = self.base(x)
return out

class VGG13(nn.Module):
def __init__(self):
super(VGG13, self).__init__()
self.num_cls = captcha_setting.MAX_CAPTCHA*captcha_setting.ALL_CHAR_SET_LEN
self.base = torchvision.models.vgg13(pretrained=False)
self.base.classifier[-1] = nn.Linear(4096, self.num_cls)
def forward(self, x):
out = self.base(x)
return out

class VGG16(nn.Module):
def __init__(self):
super(VGG16, self).__init__()
self.num_cls = captcha_setting.MAX_CAPTCHA*captcha_setting.ALL_CHAR_SET_LEN
self.base = torchvision.models.vgg16(pretrained=False)
self.base.classifier[-1] = nn.Linear(4096, self.num_cls)
def forward(self, x):
out = self.base(x)
return out

class VGG19(nn.Module):
def __init__(self):
super(VGG19, self).__init__()
self.num_cls = captcha_setting.MAX_CAPTCHA*captcha_setting.ALL_CHAR_SET_LEN
self.base = torchvision.models.vgg19(pretrained=False)
self.base.classifier[-1] = nn.Linear(4096, self.num_cls)
def forward(self, x):
out = self.base(x)
return out

class SQUEEZENET(nn.Module):
def __init__(self):
super(SQUEEZENET, self).__init__()
self.num_cls = captcha_setting.MAX_CAPTCHA*captcha_setting.ALL_CHAR_SET_LEN
self.base = torchvision.models.squeezenet1_0(pretrained=False)
self.base.classifier[-3] = nn.Linear(512, self.num_cls)
def forward(self, x):
out = self.base(x)
return out

class DENSE161(nn.Module):
def __init__(self):
super(DENSE161, self).__init__()
self.num_cls = captcha_setting.MAX_CAPTCHA*captcha_setting.ALL_CHAR_SET_LEN
self.base = torchvision.models.densenet161(pretrained=False)
self.base.classifier = nn.Linear(self.base.classifier.in_features, self.num_cls)
def forward(self, x):
out = self.base(x)
return out

class MOBILENET(nn.Module):
def __init__(self):
super(MOBILENET, self).__init__()
self.num_cls = captcha_setting.MAX_CAPTCHA*captcha_setting.ALL_CHAR_SET_LEN
self.base = torchvision.models.mobilenet_v2(pretrained=False)
self.base.classifier = nn.Linear(self.base.last_channel, self.num_cls)
def forward(self, x):
out = self.base(x)
return out
2 changes: 1 addition & 1 deletion captcha_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from captcha_cnn_model import CNN

def main():
cnn = CNN()
cnn = RES101()
cnn.eval()
cnn.load_state_dict(torch.load('model.pkl'))
print("load cnn net.")
Expand Down
2 changes: 1 addition & 1 deletion captcha_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import one_hot_encoding

def main():
cnn = CNN()
cnn = RES101()
cnn.eval()
cnn.load_state_dict(torch.load('model.pkl'))
print("load cnn net.")
Expand Down
4 changes: 2 additions & 2 deletions captcha_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import torch.nn as nn
from torch.autograd import Variable
import my_dataset
from captcha_cnn_model import CNN
from captcha_cnn_model import *

# Hyper Parameters
num_epochs = 30
batch_size = 100
learning_rate = 0.001

def main():
cnn = CNN()
cnn = RES101()
cnn.train()
print('init net')
criterion = nn.MultiLabelSoftMarginLoss()
Expand Down