mirror of
https://github.com/tickstep/aliyunpan.git
synced 2025-02-03 05:47:16 +08:00
557 lines
15 KiB
Go
557 lines
15 KiB
Go
// Copyright (c) 2020 tickstep.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
package downloader
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"github.com/tickstep/aliyunpan-api/aliyunpan"
|
|
"github.com/tickstep/aliyunpan-api/aliyunpan/apierror"
|
|
"github.com/tickstep/aliyunpan/cmder/cmdutil"
|
|
"github.com/tickstep/aliyunpan/internal/waitgroup"
|
|
"github.com/tickstep/aliyunpan/library/requester/transfer"
|
|
"github.com/tickstep/library-go/cachepool"
|
|
"github.com/tickstep/library-go/logger"
|
|
"github.com/tickstep/library-go/prealloc"
|
|
"github.com/tickstep/library-go/requester"
|
|
"github.com/tickstep/library-go/requester/rio/speeds"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// DefaultAcceptRanges is the default Accept-Ranges value.
const DefaultAcceptRanges = "bytes"
|
|
|
|
type (
|
|
// Downloader 下载
|
|
Downloader struct {
|
|
onExecuteEvent requester.Event //开始下载事件
|
|
onSuccessEvent requester.Event //成功下载事件
|
|
onFailedEvent requester.Event //成功下载事件
|
|
onFinishEvent requester.Event //结束下载事件
|
|
onPauseEvent requester.Event //暂停下载事件
|
|
onResumeEvent requester.Event //恢复下载事件
|
|
onCancelEvent requester.Event //取消下载事件
|
|
onDownloadStatusEvent DownloadStatusFunc //状态处理事件
|
|
|
|
monitorCancelFunc context.CancelFunc
|
|
globalSpeedsStat *speeds.Speeds // 全局速度统计
|
|
|
|
fileInfo *aliyunpan.FileEntity // 下载的文件信息
|
|
driveId string
|
|
loadBalancerCompareFunc LoadBalancerCompareFunc // 负载均衡检测函数
|
|
durlCheckFunc DURLCheckFunc // 下载url检测函数
|
|
statusCodeBodyCheckFunc StatusCodeBodyCheckFunc
|
|
executeTime time.Time
|
|
loadBalansers []string
|
|
writer io.WriterAt
|
|
client *requester.HTTPClient
|
|
panClient *aliyunpan.PanClient
|
|
config *Config
|
|
monitor *Monitor
|
|
instanceState *InstanceState
|
|
}
|
|
|
|
// DURLCheckFunc 下载URL检测函数
|
|
DURLCheckFunc func(client *requester.HTTPClient, durl string) (contentLength int64, resp *http.Response, err error)
|
|
// StatusCodeBodyCheckFunc 响应状态码出错的检查函数
|
|
StatusCodeBodyCheckFunc func(respBody io.Reader) error
|
|
)
|
|
|
|
//NewDownloader 初始化Downloader
|
|
func NewDownloader(writer io.WriterAt, config *Config, p *aliyunpan.PanClient, globalSpeedsStat *speeds.Speeds) (der *Downloader) {
|
|
der = &Downloader{
|
|
config: config,
|
|
writer: writer,
|
|
panClient: p,
|
|
globalSpeedsStat: globalSpeedsStat,
|
|
}
|
|
return
|
|
}
|
|
|
|
//SetClient 设置http客户端
|
|
func (der *Downloader) SetFileInfo(f *aliyunpan.FileEntity) {
|
|
der.fileInfo = f
|
|
}
|
|
|
|
func (der *Downloader) SetDriveId(driveId string) {
|
|
der.driveId = driveId
|
|
}
|
|
|
|
//SetClient 设置http客户端
|
|
func (der *Downloader) SetClient(client *requester.HTTPClient) {
|
|
der.client = client
|
|
}
|
|
|
|
// SetLoadBalancerCompareFunc 设置负载均衡检测函数
|
|
func (der *Downloader) SetLoadBalancerCompareFunc(f LoadBalancerCompareFunc) {
|
|
der.loadBalancerCompareFunc = f
|
|
}
|
|
|
|
//SetStatusCodeBodyCheckFunc 设置响应状态码出错的检查函数, 当FirstCheckMethod不为HEAD时才有效
|
|
func (der *Downloader) SetStatusCodeBodyCheckFunc(f StatusCodeBodyCheckFunc) {
|
|
der.statusCodeBodyCheckFunc = f
|
|
}
|
|
|
|
func (der *Downloader) lazyInit() {
|
|
if der.config == nil {
|
|
der.config = NewConfig()
|
|
}
|
|
if der.client == nil {
|
|
der.client = requester.NewHTTPClient()
|
|
der.client.SetTimeout(20 * time.Minute)
|
|
}
|
|
if der.monitor == nil {
|
|
der.monitor = NewMonitor()
|
|
}
|
|
if der.durlCheckFunc == nil {
|
|
der.durlCheckFunc = DefaultDURLCheckFunc
|
|
}
|
|
if der.loadBalancerCompareFunc == nil {
|
|
der.loadBalancerCompareFunc = DefaultLoadBalancerCompareFunc
|
|
}
|
|
}
|
|
|
|
// SelectParallel 获取合适的 parallel
|
|
func (der *Downloader) SelectParallel(single bool, maxParallel int, totalSize int64, instanceRangeList transfer.RangeList) (parallel int) {
|
|
isRange := instanceRangeList != nil && len(instanceRangeList) > 0
|
|
if single { // 单线程下载
|
|
parallel = 1
|
|
} else if isRange {
|
|
parallel = len(instanceRangeList)
|
|
} else {
|
|
parallel = maxParallel
|
|
if int64(parallel) > totalSize/int64(MinParallelSize) {
|
|
parallel = int(totalSize/int64(MinParallelSize)) + 1
|
|
}
|
|
}
|
|
|
|
if parallel < 1 {
|
|
parallel = 1
|
|
}
|
|
return
|
|
}
|
|
|
|
// SelectBlockSizeAndInitRangeGen 获取合适的 BlockSize, 和初始化 RangeGen
|
|
func (der *Downloader) SelectBlockSizeAndInitRangeGen(single bool, status *transfer.DownloadStatus, parallel int) (blockSize int64, initErr error) {
|
|
// Range 生成器
|
|
if single { // 单线程
|
|
blockSize = -1
|
|
return
|
|
}
|
|
gen := status.RangeListGen()
|
|
if gen == nil {
|
|
switch der.config.Mode {
|
|
case transfer.RangeGenMode_Default:
|
|
gen = transfer.NewRangeListGenDefault(status.TotalSize(), 0, 0, parallel)
|
|
blockSize = gen.LoadBlockSize()
|
|
case transfer.RangeGenMode_BlockSize:
|
|
b2 := status.TotalSize()/int64(parallel) + 1
|
|
if b2 > der.config.BlockSize { // 选小的BlockSize, 以更高并发
|
|
blockSize = der.config.BlockSize
|
|
} else {
|
|
blockSize = b2
|
|
}
|
|
|
|
gen = transfer.NewRangeListGenBlockSize(status.TotalSize(), 0, blockSize)
|
|
default:
|
|
initErr = transfer.ErrUnknownRangeGenMode
|
|
return
|
|
}
|
|
} else {
|
|
blockSize = gen.LoadBlockSize()
|
|
}
|
|
status.SetRangeListGen(gen)
|
|
return
|
|
}
|
|
|
|
// SelectCacheSize 获取合适的 cacheSize
|
|
func (der *Downloader) SelectCacheSize(confCacheSize int, blockSize int64) (cacheSize int) {
|
|
if blockSize > 0 && int64(confCacheSize) > blockSize {
|
|
// 如果 cache size 过高, 则调低
|
|
cacheSize = int(blockSize)
|
|
} else {
|
|
cacheSize = confCacheSize
|
|
}
|
|
return
|
|
}
|
|
|
|
// DefaultDURLCheckFunc 默认的 DURLCheckFunc
|
|
func DefaultDURLCheckFunc(client *requester.HTTPClient, durl string) (contentLength int64, resp *http.Response, err error) {
|
|
resp, err = client.Req(http.MethodGet, durl, nil, nil)
|
|
if err != nil {
|
|
if resp != nil {
|
|
resp.Body.Close()
|
|
}
|
|
return 0, nil, err
|
|
}
|
|
return resp.ContentLength, resp, nil
|
|
}
|
|
|
|
func (der *Downloader) checkLoadBalancers() *LoadBalancerResponseList {
|
|
var (
|
|
loadBalancerResponses = make([]*LoadBalancerResponse, 0, len(der.loadBalansers)+1)
|
|
handleLoadBalancer = func(req *http.Request) {
|
|
if req == nil {
|
|
return
|
|
}
|
|
|
|
if der.config.TryHTTP {
|
|
req.URL.Scheme = "http"
|
|
}
|
|
|
|
loadBalancer := &LoadBalancerResponse{
|
|
URL: req.URL.String(),
|
|
}
|
|
|
|
loadBalancerResponses = append(loadBalancerResponses, loadBalancer)
|
|
logger.Verbosef("DEBUG: load balance task: URL: %s", loadBalancer.URL)
|
|
}
|
|
)
|
|
|
|
// 加入第一个
|
|
loadBalancerResponses = append(loadBalancerResponses, &LoadBalancerResponse{
|
|
URL: "der.durl",
|
|
})
|
|
|
|
// 负载均衡
|
|
wg := waitgroup.NewWaitGroup(10)
|
|
privTimeout := der.client.Client.Timeout
|
|
der.client.SetTimeout(5 * time.Second)
|
|
for _, loadBalanser := range der.loadBalansers {
|
|
wg.AddDelta()
|
|
go func(loadBalanser string) {
|
|
defer wg.Done()
|
|
|
|
subContentLength, subResp, subErr := der.durlCheckFunc(der.client, loadBalanser)
|
|
if subResp != nil {
|
|
subResp.Body.Close() // 不读Body, 马上关闭连接
|
|
}
|
|
if subErr != nil {
|
|
logger.Verbosef("DEBUG: loadBalanser Error: %s\n", subErr)
|
|
return
|
|
}
|
|
|
|
// 检测状态码
|
|
switch subResp.StatusCode / 100 {
|
|
case 2: // succeed
|
|
case 4, 5: // error
|
|
var err error
|
|
if der.statusCodeBodyCheckFunc != nil {
|
|
err = der.statusCodeBodyCheckFunc(subResp.Body)
|
|
} else {
|
|
err = errors.New(subResp.Status)
|
|
}
|
|
logger.Verbosef("DEBUG: loadBalanser Status Error: %s\n", err)
|
|
return
|
|
}
|
|
|
|
// 检测长度
|
|
if der.fileInfo.FileSize != subContentLength {
|
|
logger.Verbosef("DEBUG: loadBalanser Content-Length not equal to main server\n")
|
|
return
|
|
}
|
|
|
|
//if !der.loadBalancerCompareFunc(der.firstInfo.ToMap(), subResp) {
|
|
// logger.Verbosef("DEBUG: loadBalanser not equal to main server\n")
|
|
// return
|
|
//}
|
|
|
|
handleLoadBalancer(subResp.Request)
|
|
}(loadBalanser)
|
|
}
|
|
wg.Wait()
|
|
der.client.SetTimeout(privTimeout)
|
|
|
|
loadBalancerResponseList := NewLoadBalancerResponseList(loadBalancerResponses)
|
|
return loadBalancerResponseList
|
|
}
|
|
|
|
//Execute 开始任务
|
|
func (der *Downloader) Execute() error {
|
|
der.lazyInit()
|
|
|
|
// zero file, no need to download data
|
|
if der.fileInfo.FileSize == 0 {
|
|
cmdutil.Trigger(der.onFinishEvent)
|
|
return nil
|
|
}
|
|
|
|
var (
|
|
loadBalancerResponseList = der.checkLoadBalancers()
|
|
bii *transfer.DownloadInstanceInfo
|
|
)
|
|
|
|
err := der.initInstanceState(der.config.InstanceStateStorageFormat)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
bii = der.instanceState.Get()
|
|
|
|
var (
|
|
isInstance = bii != nil // 是否存在断点信息
|
|
status *transfer.DownloadStatus
|
|
single = false // 默认开启多线程下载
|
|
)
|
|
if !isInstance {
|
|
bii = &transfer.DownloadInstanceInfo{}
|
|
}
|
|
|
|
if bii.DownloadStatus != nil {
|
|
// 使用断点信息的状态
|
|
status = bii.DownloadStatus
|
|
} else {
|
|
// 新建状态
|
|
status = transfer.NewDownloadStatus()
|
|
status.SetTotalSize(der.fileInfo.FileSize)
|
|
}
|
|
|
|
// 设置限速
|
|
if der.config.MaxRate > 0 {
|
|
rl := speeds.NewRateLimit(der.config.MaxRate)
|
|
status.SetRateLimit(rl)
|
|
defer rl.Stop()
|
|
}
|
|
|
|
// 计算文件下载的线程数
|
|
parallel := der.SelectParallel(single, MaxParallelWorkerCount, status.TotalSize(), bii.Ranges) // 实际的下载并行量
|
|
blockSize, err := der.SelectBlockSizeAndInitRangeGen(single, status, parallel) // 实际的BlockSize
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
cacheSize := der.SelectCacheSize(der.config.CacheSize, blockSize) // 实际下载缓存
|
|
cachepool.SetSyncPoolSize(cacheSize) // 调整pool大小
|
|
|
|
logger.Verbosef("DEBUG: download task CREATED: parallel: %d, cache size: %d\n", parallel, cacheSize)
|
|
|
|
der.monitor.InitMonitorCapacity(parallel)
|
|
|
|
var writer Writer
|
|
// 尝试修剪文件
|
|
if fder, ok := der.writer.(Fder); ok {
|
|
err = prealloc.PreAlloc(fder.Fd(), status.TotalSize())
|
|
if err != nil {
|
|
logger.Verbosef("DEBUG: truncate file error: %s\n", err)
|
|
}
|
|
}
|
|
writer = der.writer
|
|
|
|
// 数据平均分配给各个线程
|
|
isRange := bii.Ranges != nil && len(bii.Ranges) > 0
|
|
if !isRange {
|
|
// 没有使用断点续传
|
|
// 分配线程
|
|
bii.Ranges = make(transfer.RangeList, 0, parallel)
|
|
if single { // 单线程
|
|
bii.Ranges = append(bii.Ranges, &transfer.Range{Begin: 0, End: der.fileInfo.FileSize})
|
|
} else {
|
|
gen := status.RangeListGen()
|
|
for i := 0; i < cap(bii.Ranges); i++ {
|
|
_, r := gen.GenRange()
|
|
if r == nil {
|
|
break
|
|
}
|
|
bii.Ranges = append(bii.Ranges, r)
|
|
}
|
|
}
|
|
}
|
|
|
|
var (
|
|
writeMu = &sync.Mutex{}
|
|
)
|
|
|
|
// 获取下载链接
|
|
var apierr *apierror.ApiError
|
|
durl, apierr := der.panClient.GetFileDownloadUrl(&aliyunpan.GetFileDownloadUrlParam{
|
|
DriveId: der.driveId,
|
|
FileId: der.fileInfo.FileId,
|
|
})
|
|
time.Sleep(time.Duration(200) * time.Millisecond)
|
|
if apierr != nil {
|
|
logger.Verbosef("ERROR: get download url error: %s\n", der.fileInfo.FileId)
|
|
cmdutil.Trigger(der.onCancelEvent)
|
|
return apierr
|
|
}
|
|
if durl == nil || durl.Url == "" || strings.HasPrefix(durl.Url, aliyunpan.IllegalDownloadUrlPrefix) {
|
|
logger.Verbosef("无法获取有效的下载链接: %+v\n", durl)
|
|
cmdutil.Trigger(der.onCancelEvent)
|
|
der.removeInstanceState() // 移除断点续传文件
|
|
cmdutil.Trigger(der.onFailedEvent)
|
|
return ErrFileDownloadForbidden
|
|
}
|
|
|
|
// 初始化下载worker
|
|
for k, r := range bii.Ranges {
|
|
loadBalancer := loadBalancerResponseList.SequentialGet()
|
|
if loadBalancer == nil {
|
|
continue
|
|
}
|
|
|
|
logger.Verbosef("work id: %d, download url: %v\n", k, durl)
|
|
client := requester.NewHTTPClient()
|
|
client.SetKeepAlive(true)
|
|
client.SetTimeout(10 * time.Minute)
|
|
|
|
realUrl := durl.Url
|
|
if der.config.UseInternalUrl {
|
|
realUrl = durl.InternalUrl
|
|
}
|
|
worker := NewWorker(k, der.driveId, der.fileInfo.FileId, realUrl, writer, der.globalSpeedsStat)
|
|
worker.SetClient(client)
|
|
worker.SetPanClient(der.panClient)
|
|
worker.SetWriteMutex(writeMu)
|
|
worker.SetTotalSize(der.fileInfo.FileSize)
|
|
|
|
worker.SetAcceptRange("bytes")
|
|
worker.SetRange(r) // 分配Range
|
|
der.monitor.Append(worker)
|
|
}
|
|
|
|
der.monitor.SetStatus(status)
|
|
|
|
// 服务器不支持断点续传, 或者单线程下载, 都不重载worker
|
|
der.monitor.SetReloadWorker(parallel > 1)
|
|
|
|
moniterCtx, moniterCancelFunc := context.WithCancel(context.Background())
|
|
der.monitorCancelFunc = moniterCancelFunc
|
|
|
|
der.monitor.SetInstanceState(der.instanceState)
|
|
|
|
// 开始执行
|
|
der.executeTime = time.Now()
|
|
cmdutil.Trigger(der.onExecuteEvent)
|
|
der.downloadStatusEvent() // 启动执行状态处理事件
|
|
der.monitor.Execute(moniterCtx)
|
|
|
|
// 检查错误
|
|
err = der.monitor.Err()
|
|
if err == nil { // 成功
|
|
cmdutil.Trigger(der.onSuccessEvent)
|
|
der.removeInstanceState() // 移除断点续传文件
|
|
} else {
|
|
if err == ErrNoWokers && der.fileInfo.FileSize == 0 {
|
|
cmdutil.Trigger(der.onSuccessEvent)
|
|
der.removeInstanceState() // 移除断点续传文件
|
|
}
|
|
}
|
|
|
|
// 执行结束
|
|
cmdutil.Trigger(der.onFinishEvent)
|
|
return err
|
|
}
|
|
|
|
//downloadStatusEvent 执行状态处理事件
|
|
func (der *Downloader) downloadStatusEvent() {
|
|
if der.onDownloadStatusEvent == nil {
|
|
return
|
|
}
|
|
|
|
status := der.monitor.Status()
|
|
go func() {
|
|
ticker := time.NewTicker(1 * time.Second)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-der.monitor.completed:
|
|
return
|
|
case <-ticker.C:
|
|
time.Sleep(500 * time.Millisecond)
|
|
der.onDownloadStatusEvent(status, der.monitor.RangeWorker)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
//Pause 暂停
|
|
func (der *Downloader) Pause() {
|
|
if der.monitor == nil {
|
|
return
|
|
}
|
|
cmdutil.Trigger(der.onPauseEvent)
|
|
der.monitor.Pause()
|
|
}
|
|
|
|
//Resume 恢复
|
|
func (der *Downloader) Resume() {
|
|
if der.monitor == nil {
|
|
return
|
|
}
|
|
cmdutil.Trigger(der.onResumeEvent)
|
|
der.monitor.Resume()
|
|
}
|
|
|
|
//Cancel 取消
|
|
func (der *Downloader) Cancel() {
|
|
if der.monitor == nil {
|
|
return
|
|
}
|
|
cmdutil.Trigger(der.onCancelEvent)
|
|
cmdutil.Trigger(der.monitorCancelFunc)
|
|
}
|
|
|
|
//Failed 失败
|
|
func (der *Downloader) Failed() {
|
|
if der.monitor == nil {
|
|
return
|
|
}
|
|
cmdutil.Trigger(der.onFailedEvent)
|
|
cmdutil.Trigger(der.monitorCancelFunc)
|
|
}
|
|
|
|
//OnExecute 设置开始下载事件
|
|
func (der *Downloader) OnExecute(onExecuteEvent requester.Event) {
|
|
der.onExecuteEvent = onExecuteEvent
|
|
}
|
|
|
|
//OnSuccess 设置成功下载事件
|
|
func (der *Downloader) OnSuccess(onSuccessEvent requester.Event) {
|
|
der.onSuccessEvent = onSuccessEvent
|
|
}
|
|
|
|
//OnFailed 设置失败事件
|
|
func (der *Downloader) OnFailed(onFailedEvent requester.Event) {
|
|
der.onFailedEvent = onFailedEvent
|
|
}
|
|
|
|
//OnFinish 设置结束下载事件
|
|
func (der *Downloader) OnFinish(onFinishEvent requester.Event) {
|
|
der.onFinishEvent = onFinishEvent
|
|
}
|
|
|
|
//OnPause 设置暂停下载事件
|
|
func (der *Downloader) OnPause(onPauseEvent requester.Event) {
|
|
der.onPauseEvent = onPauseEvent
|
|
}
|
|
|
|
//OnResume 设置恢复下载事件
|
|
func (der *Downloader) OnResume(onResumeEvent requester.Event) {
|
|
der.onResumeEvent = onResumeEvent
|
|
}
|
|
|
|
//OnCancel 设置取消下载事件
|
|
func (der *Downloader) OnCancel(onCancelEvent requester.Event) {
|
|
der.onCancelEvent = onCancelEvent
|
|
}
|
|
|
|
//OnDownloadStatusEvent 设置状态处理函数
|
|
func (der *Downloader) OnDownloadStatusEvent(f DownloadStatusFunc) {
|
|
der.onDownloadStatusEvent = f
|
|
}
|