--- welcome to longking's layer ---
home m&g spirit project about
All the image textures are all from unsplash©
Hits:

Welcome
写在前面

旨在分析一个普通交叉学科学生的代码日常。

所有的分享都是基于实际自己做过的小案例, 毕竟文档大家都会看不是。

日期均为条目创建时间。




一直以来, 陆陆续续都有身边的朋友问我: 为什么要学写代码, 你是怎么学代码的(我不是科班背景, 甚至不是机械电气生信这类一般需要用到编程的专业), 该怎么学或者说该学哪些内容。 那我就放在前页这里吧, 算是这几年来的一个小总结, 也咨询了身边的科班大佬。以下全部为个人意见, 不代表完全可靠。


Why


  • 个人爱好, 系统学习之前有一点基础算是。
  • 剑走偏锋, 开辟新赛道, 未来更加灵活吧。
  • Coding对我来说, 不管是科研上工作上, 生活上社交上自己写点东西, 帮助他人也帮助自己。


What


当你要去选择一个东西时, 你要想好你希望要从这个东西上获取什么

我记得我跟别人讲过, 如果你升学进学校没有基础, 但是想学Coding的话, 首先要想好自己想用这个技能来做什么东西。大概总结了以下几种情况吧

  • 想完全转行 ==> 这种情况我反而不推荐Python。PY的优点是学起来简单容易上手; 缺点也很明显, 第一慢, 第二纯Python岗位例如算法研发岗, C9的科班CS硕都不一定能招架。 而且单论工作的话明显C++ / Java / JS甚至GoLang的工作是最多的也最普适的。 所以这个时候不如去学一门JS或者JAVA, 等你还要学PY的时候也会简单很多毕竟从难倒易很轻松, 反过来就不一定了。
  • 想半转(交叉) ==> Python(或者R语言, 对于某些专业来说)作为ML/DL的核心语言肯定是你的不二之选。这个时候我不推荐你学Matlab, 因为既然你是半转行, 那么Matlab的局限就很大了。 当然了, 等你做到部署的时候你还是要学C++, 但是咋说先把模型做出来不是。况且很多时候不需要部署。另外Java或者C++也是可以的, 比如专业性很强的软件公司, 业内企业的信息系统。
  • 只是为了搞科研 ==> Python/R/Matlab等均可, 看你的专业需求。


How


首先, 你得要会基础才能写, Hello World还有很多种写法不是。我个人的习惯是偏向视频(1.5倍速)学习, 书籍为辅。当然学习的方法仁者见仁, 适合自己的才是最好的。 对于视频方面, 比利比利大学作为国内一流大学, 每年都会输出很多很好的讲座/教学视频(bushi), 到上面找着看就行, 油管也可以当然需要你的英语水平达标了。一定要找那种讲的详细的, 多看几遍不要紧的。 不过要注意一点, 因为这类视频基本上都是面向就业的, 所以他们都会涉及到很多工作上的用途内容, 比如Linux操作/多并发等, 这些内容在你学习基础的时候是可以直接不看的。因此打个比方对于一个700集的视频来说, 基本上只用看1/3就可以了。 效率为王。买课的话, 我觉得没必要, 因为哪怕是身边科班的同学, 他们告诉我的情况是即便是科班甚至TOP科班, 大部分的Code Tricks也是非课堂时间学到的, 而且有的卖课的真的很水。

这里我想解释一下, 举个小例子。对于计算机专业或者机械电气信息等的同学来说, Linux很重要。事实上它们对我也很重要因为我的模型是在服务器上跑的, 操作系统是CentOS。 但是对于一些传统专业, 甚至都没有类似的编程环境和设备供使用, 与其去学一些如果不转行甚至都用不上的东西, 不如好好研究一下自己领域的东西不是吗。

Coding这个东西, 跟数学很像其实: 看着很简单, 一看就感觉会了, 一写就直接寄了, 一定要多练。这个时候有人就会问我怎么练。大概有两种渠道吧:

  • 刷题网站 ==> Leetcode / pycheckio 等
  • 在Github / CSDN 等上找案例(或者干脆自己想一个)自己练, 比如你在学PyQt5, 那么你就可以去自己写一个小计算器, 甚至一个文件管理系统等。 要是ML/DL的话就更好说了, 国外去Kaggle / UCIMLR / AWS, 国内去天池 / PaddlePaddle 等直接找数据集, 学习别人优秀的代码怎么写的。 记得我以前审过一篇某SCI Q1的文章, 交叉学科内容的, 作者很贴心附上了源码, 不过吧那个源码写的真的是一言难尽。后来把它拒了, 不过主要原因还是他的Method有问题跟代码没啥关系。 这里我推荐一个清华毕业的微软大佬讲的Coding过程中命名方式的基本准则, 毕竟代码的可读性也很重要。链接这里

然后我个人是极其推荐各位去刷刷题练练手, 看看别人代码的。很多方法可能你都不知道以至于花费了很多的时间去做无用功, 另外Coding是一件多人协作的东西, 你的代码质量关乎你的后期维护。 尤其对于交叉学科的同学来说, 看文档本身就是一件难以在短时间内习惯的事情。通过阅读别人的代码, 不仅能增加自己代码的编写水平, 掌握相应的Coding技巧, 还能了解到很多很好用的工具和方法。 任何语言也适用。


第二, 由于现在ML/DL的火热, 估计很多人在掌握好基础后都会想到往算法这边靠一点, 至少对研究有很大的好处。

基础当然是吴恩达老师的机器学习基础, 比利比利上有同源视频。虽然只用的话一个Sklearn / Pytorch 等Import就能解决, 但是再怎么说, 原理你得知道吧, 写代码虽说不一定什么时候都有造轮子的能力不能老当调包侠。 之后我是看的李航的数学统计方法, 深入了解这些ML Method的原理。之后又看了一些东西, 包括这些库怎么用, 若干案例分享七七八八的, 然后就去Kaggle上做练习去了。

DL方面除了上面所提到的一些东西, arkiv上的论文, 比利比利上的各位大佬的精读讲解和经验分享(Mu Li is all your need)甚至知乎上的大佬(知乎计算机区还是很可靠的)都是可以学习的。 我认为CS领域给我留下最深刻的一点印象就是, 虽然真的很卷全是怪物, 但是很开放, 大家都愿意把自己写的代码放出来给别人参考或者指正。或许这也是它卷的根源之一, 但是对于小白来说真的极其友好。

之后因为要做部署, 就是顺带学了一下C++。从简单到困难真的很不容易, 一个指针概念我都看了三天还专门上油管看的印度人的视频, 劝大家学有余力的话都去试一下。

顺带一提, 当你在Coding上遇到Bug或者困难的时候, 多去查一查, Github / Stackoverflow / CSDN / 甚至专属的论坛, 你的错误可能别人早都发现并且解决了。 Coding有时候就像炼丹, Bug千奇百怪, 问一两个人真不一定能得到解决结果, 需要自己具有找寻答案的能力。


此外就是我个人学的一些其他东西了比如 MYSQL / HTML / JS / Java, 包括这些大类下的各类分支框架和要点。这么多语言这么多框架, 前端和后端这么多东西, 还是多学一点好, 也是一个完整项目的必要条件不是吗。

2022-08-31
Instance Segmentation 调参经验

最近做的一个实例分割的任务, 先上调参前后的图片对比, 模型在经过多重比对后均使用SOLOV2, backbone为ResNet-18, 没有大量魔改模型(比如加Attn等)

差的还是挺明显的。

数据集概况: 三类, 其中第一类为主要类别, 后两类是我希望让网络学到的废类别(相当于一种先验知识)。不然是很容易将这三类互相搞混的(实验得到)。三类数目分别为4700 3600 300, 每张图片大概有100-200个目标(三类加在一起)。图片大小为1200*800的双通道图(Hik原本拍出来的是3072*2048)。

设备概况: 一张NVIDIA-V100 32G, Python 3.9, Torch 1.11, CUDA 11.3, 基本上Batch只能有4, 训练500Epoch需要16h左右。顺带一提我这边的实验前40个Epoch AP稳定为0, 当时吓得以为我哪里整错了。


调参嘛, 无非就是各种loss, lr, 以及结构参数之类的这些东西

首先, 这里的类别Loss选择的是Focal Loss。 由于FL主要是为了解决类别不平衡的问题, 其中alpha控制类别不平衡即alpha越小表示负样本越多; gamma控制难易样本的损失降低简单样本的损失值。 而这里如果把它肯成一个二分类问题的话, 其实还算比较平衡的, 因此将默认的alpha=0.25设置成了0.4-0.45。 分割Loss用的是DiceLoss, 就用的原始的参数, 还行其实。

其次, 优化器选择的是原始的SGD(还蛮难下去的, 基本上就是前大后小), 初始学习率为0.05, 衰减方法选择的是StepLR。 由于我这里的BS和GPU个数较少, 一般所推荐的初始0.1或者0.01的学习率需要降低后再进行使用。 这里使用的是线性缩放原理, 参考文献为Accurate, Large Minibatch SGD:Training ImageNet in 1 Hour

第三, 预处理部分。这里的第三类(第二个坏类别)的划分依据是由于图片的尺寸限制, 使得边缘的物体被切割而成。因此取消了所有的裁剪类型的数据增强, 包括mixup, clip等。 同时为了弥补这方面的损失, 提高了Flip, Resize, ColorJitter等方法的比例. 另外Simple Copy Paste在该任务上效果较好。RandAugment暂时没有用, 不知道效果如何。

同时, 以上调参大大增强了小目标的检测能力, 比Backbone为ResNet-50的原始SoloV2模型mAPs高了将近10倍, 蛮离谱的。

2022-09-02
Selenium 爬虫框架

起因来源于实习中的一个小任务, 平常是基本不做爬虫的

既然提到了Selenium这个框架, 那肯定是基本的url.request无法满足这里的需求


在依靠ajax等框架加载的网页中, 由于其异步问题, 网页中的所有内容是不会一下子全部加载出来的需要靠脚本后端往服务器里要数据, 通过url.request打印的信息也是如此只能打印一部分静态网页和大量的JS脚本 而其实远远不止这些。具体情况如下图所示:

一般而言, url.request需要配合正则匹配来进行对应元素目标位置的锁定。但是既然出不来那也就无从谈起。

因此, 我这里选择的方法是。采用Selenium通过搜索整个html code里的元素path, 来达到对应图片的爬取。 此外由于该网页存在下拉加载, 因此需要同时进行自动的下拉操作。

Selenium框架会生成一个对应的浏览器窗口模拟行为进行网页的浏览和抓取

driver = webdriver.Chrome(options=hide())
driver.maximize_window()
driver.get(url) # url为对应的域名
                        

这里的options调整了两个params, 都算是Chrome官方推荐的

chrome_options = Options()
chrome_options.add_argument('--disable-gpu')
chrome_options.add_argument('--hide-scrollbars')
                        

之后需要实现下拉功能, 一段简单的js代码:

# 大概要下拉多少次, 300为在网页上所呈现的图片小图的高度估值, 可小不可太大。
for i in range(1, 10000): 
    if (i - 1) % 4 == 0:
        dis = 300 * ((i - 1) // 4)
        js = "var q=document.documentElement.scrollTop={}".format(dis)
        driver.execute_script(js)

Selenium给我们提供了很多找元素的参照, 包括class / id / xpath。具体路径可在devtools里面找到, 对应元素右键即可。 重要的是找到你想要爬下来的这一批元素的共同点, 比如我这里是用Xpath路径, 只有一个地方不一样且是按规律变化的:

xpath = '//*[@id="page_index"]/div/div[1]/div[2]/div[2]/div/div/div[{}]/img'.format(i)
elements = driver.find_elements_by_xpath(xpath)

之后就是拿图片的地址并保存了。众所周知img标签下的地址为src。

for ele in elements:
    name = ele.text
    img_url = ele.get_attribute('src')
    print(i, name, img_url)
    if img_url:
        request.urlretrieve(img_url, f'./{title}/{i}.jpg')
        suc += 1
        print('-------------------------------')
    else:
        fal += 1
        print('----  this process errored, success = {}, failed = {}  ----'.format(suc, fal))
        break

整体代码

from urllib import request
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
import time

def hide():
    chrome_options = Options()
    chrome_options.add_argument('--disable-gpu')
    chrome_options.add_argument('--hide-scrollbars')
    return chrome_options

def process(url, title):
    driver = webdriver.Chrome(options=hide())
    driver.maximize_window()
    driver.get(url)

    suc, fal = 0, 0

    for i in range(1, 10000):
        if (i - 1) % 4 == 0:
            dis = 300 * ((i - 1) // 4)
            js = "var q=document.documentElement.scrollTop={}".format(dis)
            driver.execute_script(js)

        xpath = '//*[@id="page_index"]/div/div[1]/div[2]/div[2]/div/div/div[{}]/img'.format(i)
        elements = driver.find_elements_by_xpath(xpath)
        if not elements:
            break
        time.sleep(0.2)
        for ele in elements:
            name = ele.text
            img_url = ele.get_attribute('src')
            print(i, name, img_url)
            if img_url:
                request.urlretrieve(img_url, f'./{title}/{i}.jpg')
                suc += 1
                print('-------------------------------')
            else:
                fal += 1
                print('----  this process errored, success = {}, failed = {}  ----'.format(suc, fal))
                break

def main():
    url2title = ['https://111.com', '1'],
                ['https://222.com', '2'],
                ['https://333.com', '3'],
                ['https://444.com', '4']]
    process()

main()

然后是一个同学委托的企查查的网页数据爬取, 从一个EXCEL表格读取一系列指定的企业名称并获得三项信息

具体代码和注释见下, 值得注意的是这类需要登录的网站, 需要指定个人保存登录信息的文件夹, 并事先给予自己一定的前置时间。 另外整体操作不能太快, 例如爬取一家公司的信息后间隔一会儿再进行下一个公司信息的爬取进程, 否则容易被限制IP (企查查可以通过购买会员的方式解决, 但别的网站就不一定了)。

from selenium import webdriver
from selenium.webdriver.common.by import By
import pandas as pd
import numpy as np
import re
import time


def opt():
    '''
        Chrome浏览器作为爬虫主体, 大概设置
    '''

    option = webdriver.ChromeOptions()
    option.add_experimental_option('excludeSwitches', ['enable-automation'])
    option.add_experimental_option('useAutomationExtension', False)
    option.add_argument(
        '--user-agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/65.0.3325.146 Safari/537.36"')  # user-agent请求头伪装
    option.add_argument(
        r"user-data-dir=C:\Users\Longking\AppData\Local\Google\Chrome\User\User Data")  # 登录信息保存默认文件夹
    return option


def get_page_html(firm_name, driver, interval=0.3):
    inputBox = driver.find_element(By.CSS_SELECTOR, '#searchKey')  # 获取搜索栏
    searchBtn = driver.find_element(By.CSS_SELECTOR, '.input-group-btn>.btn-primary')  # 获取搜索按钮

    time.sleep(interval)  # 等待-输入-等待-点击-跳转
    inputBox.clear()
    inputBox.send_keys(firm_name)
    time.sleep(interval)
    driver.execute_script('arguments[0].click()', searchBtn)

    try:
        # 万一没有这家公司就直接跳过
        css_txt = "body>div>div.app-search>div.container.m-t>div.adsearch-list>div>div.msearch.select-search-enable>div>table>tr>td:nth-child(3)>div>div.app-copy-box.copy-hover-item.copy-part>span.copy-title"  # 获取点击的超链接
        itemLink = driver.find_element(By.CSS_SELECTOR, css_txt)
        time.sleep(interval)
        itemLink.click()

        try:
            # 万一没有注销时间直接跳过
            all_handle = driver.window_handles
            driver.switch_to.window(all_handle[-1])  # 将句柄置于当前窗口

            css_txt = ".bar-click"  # 获取位置元素
            try:

                zhuxiao_time = driver.find_element(By.CSS_SELECTOR, css_txt)
                zhuxiao = zhuxiao_time.text
                zhuxiao = re.findall(r'[0-9]+-[0-9]+-[0-9]+', zhuxiao)[0]  # 剪掉多余信息
            except:
                zhuxiao = 'none'
            time.sleep(interval)

            css_txt = "#cominfo>div.cominfo-normal>table>tr:nth-child(5)>td:nth-child(4)"
            try:
                yingye_qixian = driver.find_element(By.CSS_SELECTOR, css_txt)
                yingye = yingye_qixian.text
            except:
                yingye = 'none'
            time.sleep(interval)

            css_txt = "#cominfo>div.cominfo-normal>table>tr:nth-child(6)>td:nth-child(6)>div>span.copy-value"
            try:
                hezhun_qixian = driver.find_element(By.CSS_SELECTOR, css_txt)
                hezhun = hezhun_qixian.text
            except:
                hezhun = 'none'
            time.sleep(interval)

            return [zhuxiao, yingye, hezhun], True
        except:
            return ['level2', 'level2', 'level2'], False
    except:
        return ['leve1', 'leve1', 'leve1'], False


def main(login_prepare=5, change=10):
    """ 以获取注销时间为案例, login_prepare为预留的手动登录的时间, change为切换公司的等待时间, 单位为秒 """

    data = pd.read_excel('output.xlsx')  # 表格第一行不是表头, 所以跳过从第二行开始读

    total = data.shape[0]  # 总行数
    zhuxiao_list = [''] * total  # 列表程度保持和总行数相同
    yingye_list = [''] * total
    hezhun_list = [''] * total

    if '注销时间' in data.columns:
        zhuxiao_list = list(data['注销时间'])
        yingye_list = list(data['营业期限'])
        hezhun_list = list(data['核准日期'])

        for i in range(total):
            if type(zhuxiao_list[i]) == float:
                idx = i
                print(idx)
                break
    else:
        idx = 0

    driver = webdriver.Chrome(options=opt())  # 浏览器选项加入
    driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
        "source": """
        Object.defineProperty(navigator, 'webdriver', {
            get: () => undefined
        })
        """
    })  # 防止被检测
    driver.maximize_window()  # 窗口最大化, 小了的话可能拿不到元素
    driver.get('https://www.qichacha.com/user_login')  # 从登录开始
    time.sleep(login_prepare)  # 等待手动登录时间

    while idx < total:  # 获取每一个公司名称
        firm = data['企业名称'][idx]
        print(f'---- Now ==> {firm}')
        [zx, yy, hz], condition = get_page_html(firm, driver)  # 主函数
        print([zx, yy, hz])  # 没有填None, 有的话就填原来的。
        if condition:
            zhuxiao_list[idx] = zx
            yingye_list[idx] = yy
            hezhun_list[idx] = hz
        else:
            zhuxiao_list[idx] = 'none'
            yingye_list[idx] = 'none'
            hezhun_list[idx] = 'none'

        memory = idx + 1

        if memory % 10 == 0:  # 每10个公司保存一次, 可自己更改
            data_copy = data.copy()
            data_copy.loc[:, '注销时间'] = zhuxiao_list
            data_copy.loc[:, '营业期限'] = yingye_list
            data_copy.loc[:, '核准日期'] = hezhun_list
            data_copy.to_excel('output.xlsx', index=None)

        idx += 1
        time.sleep(change)

    data_copy = data.copy()
    data_copy.loc[:, '注销时间'] = zhuxiao_list
    data_copy.loc[:, '营业期限'] = yingye_list
    data_copy.loc[:, '核准日期'] = hezhun_list
    data_copy.to_excel('output.xlsx', index=None)


main()
                    
2022-09-04
Github-Pages Tips

该条目长期更新, 为个人使用Github服务器挂靠网页时所遇到的问题记录, 尽量可以帮到希望使用Github-Pages作为自己的第一个个人主页的朋友。

Github-Pages整体用下来的感觉是: 很方便只需要写完后Push就行了; 不收你一分钱; 不会HTML的话也有基于Markdown的模板不过Theme有限。 缺点是国内网络在使用(push)的时候可能会受到客观的阻碍(建议使用全局模式), 访问是没问题的; 另外这是一个纯前端网页, 毕竟人家不会让你从服务器上拿数据, 你的表单也不知道提交到哪里。

关于Github怎么注册, 怎么申请Page的Repo, 怎么push你的project, 这些步骤很简单搜一下就有了。好像Gitee也支持Pages功能了我没试过。


1. index.html

众所周知, 每一个网站都是有个主页的。而Github-Pages的主页的网页名称必须为index, 且必须放在repo的最上层目录。如果你是用Markdown使用模板作为语言编写, 则还需要带有config.yaml(这个东西创建完主题会自带), 这个yaml文件里保存了你所使用的主题, 网页的名称, header显示的内容等。

另外提一嘴, 如果你从没有过代码项目的经历, 请记住所有的路径引用必须为相对路径, 毕竟你也不知道github给你的绝对路径是什么。网站代码开源可以去我的Github Repo里面直接查看。


2. Error with Permissions-Policy header

完整报错命令为Error with Permissions-Policy header: Origin trial controlled feature not enabled: 'interest-cohort'. 可以在DevTools中查看。我这边发现此错误的原因是使用了JQuery官网下载下来的脚本库, 在Chrome浏览器上触发了谷歌的同类群组联合学习技术: Federated Learning of Cohorts(edge浏览器正常运行), 而Github也是反该项技术可能会抓取用户隐私的成员之一, 具体解释可查看这里。解决方法是使用CDN分发, 我这边用的是微软的分发。

2022-11-26
如何规避Pyinstaller存在的种种打包Bug

众所周知, 在Py上打包一些较为简易的程序, 大多数还是会用Pyinstaller。但是吧代码量小还好, 代码量大的话经常会出现这个没打包进去, 那个DDL缺失的情况。 最近各类大大小小的软件也做了不少, 总结了一下值得注意的东西。


怎么使用就不介绍了, 我一般不会把其打包成一个总的, 运行速度太慢, 即加上参数 -D -w。 有人会问如果我的程序打包存在问题, 但我没法儿发现怎么办。这里我所知道的大概有三种解决方法:

  1. Python traceback
  2. 比如你的代码里, 某一个类是主类然后主要也靠它调用东西, 那么你就可以用traceback去掌握报错信息。 此方法也特别适合你懒得写那么多的专门的ValueError的提示框和判断, 也适合打包程序后的Debug处理。

    class MainClass:
        ...
        ...
        def start(self):
            ...
                                    

    这个时候就很适合去加一个hook并回溯报错了。具体是这样的(PyQt5的程序)

    import sys
    
    class MainClass:
        ...
        def __init__(self):
            self.old_hook = sys.excepthook
            sys.excepthook = self._catch_error
    
        def _catch_error(self, ty, value, tb):
            traceback_format = traceback.format_exception(ty, value, tb)
            traceback_string = "".join(traceback_format)
            QtWidgets.QMessageBox.critical(self, "额外报错, ", "{}".format(traceback_string))
            self.old_hook(ty, value, tb)
                                    

    这样的话不管是什么稀奇古怪的报错, 都可以起到提示的作用, 你在IDE里看到的是啥报错这里就给你弹的是什么。可以掌管任何在此class里引用出现的函数变量和类。

  3. 命令行执行
  4. 方法如其名, windows cmd + xxx.exe 即可, 报错信息直接在命令行中输出。

  5. Pyinstaller 打包时, 更改Spec文件
  6. 众所周知, 打包时一般会生成三类文件: build dir(打包流程) / dist dir(最终的软件) / xxx.spec(打包配置文件)。

    而spec文件中可以通过更改 console和debug参数, 输出更为详细的报错信息, 并更改最终软件的形式来达到Debug的效果。但此方法支持的报错不支持软件内部操作报错。


拿到了报错信息, 就要进行对应的操作了, 要是Code本身的报错还好, 看得懂知道在哪儿改。 但是由于Pyinstaller自身Package打包的兼容性问题, 经常会出现其他的一些问题。这里举几个经常会出现的问题。

  1. Module xxx not found
  2. 提示说有什么包没打包进来, 常见的有scipy, xgboost, py4j。

    解决方法: 直接将对应环境site-packages下的对应第三方包文件夹整体复制进来即可。

  3. DLL load failed while importing xxx
  4. 在导入某第三方包时找不到对应的DLL文件, 信息会直接在锁定在import xxx那一行。

    解决方法大致有三种: 该第三方包的版本打包兼容性有误, 卸载并重新安装别的版本即可。 网络方法, 程序目录下缺少python3.dll文件, 将对应环境site-packages下的python3.dll和python3x.dll复制进来即可。 有时候由于代码复杂, 某些DLL会存在没有复制进来的情况。若其他方法均失败, 编写一个基于相同importing package的简易软件并打包。 若后者正常运行, 检查两者目录下DLL文件的异同, 并将后者目录下的DLL全数复制进第一个, 筛选复制也可以。哪怕是原也得打包Qt不是↓

2023-03-05
SpringBoot 3+ 整合MyBatis和MyBatisPlus

困扰了我一个月的问题。

MyBatis和MyBatisPlus 二者之间关联较小, 无非是继承和被继承的关系, 基本上存在冲突也会报错。

结论: 如果SpringBoot版本较高, 请使用MyBatisPlus 3.5.3.1及以上的版本, 否则会提醒缺少sqlSessionFactorysqlSessionTemplate => 即Property 'sqlSessionFactory' or 'sqlSessionTemplate' are required

若无法下载, 需要添加镜像文件

<repositories>
    <repository>
        <id>ossrh</id>
        <name>OSS Snapshot repository</name>
        <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
        <release>
            <enabled>false</enabled>
        </release>
        <snapshots>
            <enabled>true</enabled>
        </snapshots>
    </repository>
</repositories>
                        
2023-03-24
MMEditing 简单使用

没研究过也更没做过超分, 但是项目又得用超分, 有人会选择BasicSR, 这边还是选择的MMEditing, 商汤的框架用惯了, 懒得写的时候直接就用。

网上的介绍比较少, 官方1.x的文档更新的也很有限。稍微写一下, 我这边只做了图像超分所以也只能写图像超分。

请注意, 该帮助仅支持0.x版本, 且似乎不支持MMagic。


图像超分顾名思义, 你希望将一张分辨率较低的图像送进网络去获得一张高分辨率的图片。 因此在做数据的时候, 需要高分辨率的原图和低分辨率的对应图片。MMEditing在超分这块仅支持DIV2K数据集, 且为修改过的DIV2K数据集。 整个数据集分为三块: 本身和下采样2 3 4倍的对应图像, Set5和Set14的两个验证集。请注意, Set5 和 Set14数据集也是超分的Benchmark之一, 下载地址请见这里。不过最终发现不要这个也行。

整个数据集文件夹最终的呈现效果如下方所示。请注意: 所有文件必须为png格式, 且不同文件夹下对应图片名称相同。

data
├─DIV2K
│  ├─DIV2K_train_HR         <--- 原分辨率的图片, 训练集
│  ├─DIV2K_train_LR_bicubic
│  │  ├─X2          <--- 2倍下采样
│  │  ├─X3          <--- 3倍下采样
│  │  └─X4          <--- 4倍下采样
│  ├─DIV2K_valid_HR         <--- 原分辨率的图片, 验证集
│  └─DIV2K_valid_LR_bicubic
│      ├─X2          
│      ├─X3
│      └─X4
├─val_set14         <--- 更改名字即可
│  ├─Set14_mod12
│  ├─original
│  ├─Set14_bicLRx2
│  ├─Set14_bicLRx3
│  └─Set14_bicLRx4
└─val_set5          <--- 更改名字即可
    ├─Set5_mod12
    ├─original
    ├─Set5_bicLRx2
    ├─Set5_bicLRx3
    └─Set5_bicLRx4

第一步当然是要做这个数据集。

下采样的方法其实从文件夹命名就可以看出来, 一个bicubic就能搞定。 代码参照这里并做了下修改

import os
import argparse
import cv2
 
parser = argparse.ArgumentParser(description='Downsize images at 2x using bicubic interpolation')
parser.add_argument("-k", "--keepdims", help="keep original image dimensions in downsampled images", action="store_true")
parser.add_argument('--hr_img_dir', type=str, default=None,
                    help='path to high resolution image dir')
parser.add_argument('--lr_img_dir', type=str, default=None,
                    help='path to desired output dir for downsampled images')
args = parser.parse_args()
 
hr_image_dir = args.hr_img_dir
lr_image_dir = args.lr_img_dir
 
 
os.makedirs(lr_image_dir + "/X2", exist_ok=True)
os.makedirs(lr_image_dir + "/X3", exist_ok=True)
os.makedirs(lr_image_dir + "/X4", exist_ok=True)
 
supported_img_formats = (".bmp", ".dib", ".jpeg", ".jpg", ".jpe", ".jp2",
                         ".png", ".pbm", ".pgm", ".ppm", ".sr", ".ras", ".tif",
                         ".tiff")
 

for filename in os.listdir(hr_image_dir):
    if not filename.endswith(supported_img_formats):
        continue
 
    name, ext = os.path.splitext(filename)
 
    hr_img = cv2.imread(os.path.join(hr_image_dir, filename))
    hr_img_dims = (hr_img.shape[1], hr_img.shape[0])
 
    hr_img = cv2.GaussianBlur(hr_img, (0,0), 1, 1)
    lr_image_2x = cv2.resize(hr_img, (0,0), fx=0.5, fy=0.5, interpolation=cv2.INTER_CUBIC)
    if args.keepdims:
        lr_image_2x = cv2.resize(lr_image_2x, hr_img_dims, interpolation=cv2.INTER_CUBIC)
 
    cv2.imwrite(os.path.join(lr_image_dir + "/X2", filename.split('.')[0] + ext), lr_image_2x)
 
    lr_img_3x = cv2.resize(hr_img, (0, 0), fx=(1 / 3), fy=(1 / 3),
                           interpolation=cv2.INTER_CUBIC)
    if args.keepdims:
        lr_img_3x = cv2.resize(lr_img_3x, hr_img_dims,
                               interpolation=cv2.INTER_CUBIC)
    cv2.imwrite(os.path.join(lr_image_dir + "/X3", filename.split('.')[0] + ext), lr_img_3x)
 
    lr_img_4x = cv2.resize(hr_img, (0, 0), fx=0.25, fy=0.25,
                           interpolation=cv2.INTER_CUBIC)
    if args.keepdims:
        lr_img_4x = cv2.resize(lr_img_4x, hr_img_dims,
                               interpolation=cv2.INTER_CUBIC)
    cv2.imwrite(os.path.join(lr_image_dir + "/X4", filename.split('.')[0] + ext), lr_img_4x)

第二步在MMEditing内预处理数据集

官方提供的数据集格式有 SRAnnotationDataset --> 以图片文件和独立的ann.txt文件组成; SRFolderDataset --> 直接以图片文件组成; SRLmdbDataset --> 以图片和lmdb文件组成。官方建议的是使用lmdb数据集, io处理速度快, 但是以文件夹为核心的数据集也不错, 起码制作起来简单。

为了扩充数据 (和获取ann.txt及lmdb文件), 首先得进一遍 tools/data/super-resolution/div2k/preprocess_div2k_dataset.py。 具体命令为 (假设你的数据文件夹名为data, 如同上方所示):

python tools/data/super-resolution/div2k/preprocess_div2k_dataset.py --data-root ./data/DIV2K --make-lmdb

之后你的data文件夹多了几个lmdb和_sub文件夹, 如下方所示。_sub文件夹内即为经裁剪后的图片, lmdb就是所需的数据文件了。

data
├── DIV2K
│   ├── DIV2K_train_HR
│   ├── DIV2K_train_HR_sub
│   ├── DIV2K_train_HR_sub.lmdb
│   │   ├── data.mdb
│   │   ├── lock.mdb
│   │   ├── meta_info.txt
│   ├── DIV2K_train_LR_bicubic
│   │   ├── X2
│   │   ├── X3
│   │   ├── X4
│   │   ├── X2_sub
│   │   ├── X3_sub
│   │   ├── X4_sub
│   ├── DIV2K_train_LR_bicubic_X2_sub.lmdb
│   ├── DIV2K_train_LR_bicubic_X3_sub.lmdb
│   ├── DIV2K_train_LR_bicubic_X4_sub.lmdb
│   ├── ...

这时候有人就会问为什么valid没有执行这一步骤, 因为默认确实不执行, 不过也好说。train_pipeline用lmdb, val_pipeline就用folder呗。毕竟val也不做裁剪。


第三步修改配置文件

如果你之前接触过类似于paddlepaddle, mmcv这类多任务的开源框架的话, 改配置对你来说就肯定不算难事情了。 你如果不改网络结构的话, 也就Augmentation / lr_config / loss and optimizer 和日志以及保存部分需要改。

不过既然是超分, 那就难免有不同的地方。所以就列举几个比较特殊的需要更改的点, 大体例子参照官方提供的EDSR模型配置文件。不是很了解的也可以顺便看看对应参数的解释。 请注意, 该配置文件仅训练2倍下采样的数据。可在scale处修改。

如果使用的是 SRAnnotationDataset 数据集格式进行训练, 则需要修改以下位置:

# data -> train -> dataset
lq_folder='data/DIV2K/DIV2K_train_LR_bicubic/X2_sub'
gt_folder='data/DIV2K/DIV2K_train_HR_sub'
ann_file='data/DIV2K/DIV2K_train_HR_sub.lmdb/meta_info.txt'         <--- 经preprocess_div2k_dataset.py处理后生成的txt文件

# data -> val or test -> dataset
lq_folder='./data/val_set5/Set5_bicLRx2'
gt_folder='./data/val_set5/Set5_mod12'          <--- 这里可以改为你想要当作验证和测试的数据文件夹, 因为默认是 SRFolderDataset 因此没有ann.txt

如果使用的是 SRLmdbDataset 数据集格式进行训练, 则需要修改以下位置:

# train_dataset_type
train_dataset_type = 'SRLmdbDataset'

# data -> train -> dataset
lq_folder='data/DIV2K/DIV2K_train_LR_bicubic_X2_sub.lmdb'
gt_folder='data/DIV2K/DIV2K_train_HR_sub.lmdb
# 并注释掉原来的ann_file

# train_pipeline, 前两个dict
dict(
    type='LoadImageFromFile',
    io_backend='lmdb',
    key='lq',
    db_path='data/DIV2K/DIV2K_train_LR_bicubic_X2_sub.lmdb',
    flag='unchanged'),
dict(
    type='LoadImageFromFile',
    io_backend='lmdb',
    key='gt',
    db_path='data/DIV2K/DIV2K_train_HR_sub.lmdb',
    flag='unchanged')

其他修改:

# evaluation
# 删除gpu_collect项

第四步训练, 懂的都懂。


第五步, 直接用api作推理。

如果你曾经使用过mmcv任何工具框架的api, 可能会像下面这样做:

from mmedit.apis import init_model, restoration_inference

def infer_restoration(img, cfg, ckpt):
    model = init_model(config=cfg, checkpoint=ckpt, device="cuda:0")
    infer = restoration_inference(model=model, img=item))
    return infer

if __name__ == '__main__':
    res = infer_restoration(img="xxx.jpg",
                            cfg="xxx.py",
                            ckpt="xxx.pth")

然后你会发现, res矩阵里面全都是在0-1之间的数, 你可能会想到是不是做了归一化。 之后一通 *255 再 reshape 再 tensor.numpy(), 发现出来的结果怎么是原低分辨率的图复制了几次而已。 没辙了。

其实有一个叫做tensor2img的函数是需要一起配套使用的, 阅读源码发现是需要再另外做一系列后处理的, 即如下面所示即可:

from mmedit.apis import init_model, restoration_inference
from mmedit.core import tensor2img

def infer_restoration(img, cfg, ckpt):
    model = init_model(config=cfg, checkpoint=ckpt, device="cuda:0")
    infer = tensor2img(restoration_inference(model=model, img=item)))
    return infer

if __name__ == '__main__':
    res = infer_restoration(img="xxx.jpg",
                            cfg="xxx.py",
                            ckpt="xxx.pth")


Right, wrong... Nobody's got a clue what the difference is in this town. So I'm gonna have more fun... and live crazier than any of 'em.
Goro Majima