mirror of
https://github.com/tickstep/aliyunpan.git
synced 2025-01-23 22:42:15 +08:00
556 lines
16 KiB
Go
556 lines
16 KiB
Go
// Copyright (c) 2020 tickstep.
|
||
//
|
||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
// you may not use this file except in compliance with the License.
|
||
// You may obtain a copy of the License at
|
||
//
|
||
// http://www.apache.org/licenses/LICENSE-2.0
|
||
//
|
||
// Unless required by applicable law or agreed to in writing, software
|
||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
// See the License for the specific language governing permissions and
|
||
// limitations under the License.
|
||
package downloader
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"github.com/tickstep/aliyunpan-api/aliyunpan"
|
||
"github.com/tickstep/aliyunpan-api/aliyunpan/apierror"
|
||
"github.com/tickstep/aliyunpan/cmder/cmdutil"
|
||
"github.com/tickstep/aliyunpan/internal/config"
|
||
"github.com/tickstep/aliyunpan/internal/waitgroup"
|
||
"github.com/tickstep/aliyunpan/library/requester/transfer"
|
||
"github.com/tickstep/library-go/cachepool"
|
||
"github.com/tickstep/library-go/logger"
|
||
"github.com/tickstep/library-go/prealloc"
|
||
"github.com/tickstep/library-go/requester"
|
||
"github.com/tickstep/library-go/requester/rio/speeds"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
// DefaultAcceptRanges is the default Accept-Ranges value.
const DefaultAcceptRanges = "bytes"
|
||
|
||
type (
|
||
// Downloader 下载
|
||
Downloader struct {
|
||
onExecuteEvent requester.Event //开始下载事件
|
||
onSuccessEvent requester.Event //成功下载事件
|
||
onFailedEvent requester.Event //成功下载事件
|
||
onFinishEvent requester.Event //结束下载事件
|
||
onPauseEvent requester.Event //暂停下载事件
|
||
onResumeEvent requester.Event //恢复下载事件
|
||
onCancelEvent requester.Event //取消下载事件
|
||
onDownloadStatusEvent DownloadStatusFunc //状态处理事件
|
||
|
||
monitorCancelFunc context.CancelFunc
|
||
globalSpeedsStat *speeds.Speeds // 全局速度统计
|
||
|
||
fileInfo *aliyunpan.FileEntity // 下载的文件信息
|
||
driveId string
|
||
loadBalancerCompareFunc LoadBalancerCompareFunc // 负载均衡检测函数
|
||
durlCheckFunc DURLCheckFunc // 下载url检测函数
|
||
statusCodeBodyCheckFunc StatusCodeBodyCheckFunc
|
||
executeTime time.Time
|
||
loadBalansers []string
|
||
writer io.WriterAt
|
||
client *requester.HTTPClient
|
||
panClient *config.PanClient
|
||
config *Config
|
||
monitor *Monitor
|
||
instanceState *InstanceState
|
||
}
|
||
|
||
// DURLCheckFunc 下载URL检测函数
|
||
DURLCheckFunc func(client *requester.HTTPClient, durl string) (contentLength int64, resp *http.Response, err error)
|
||
// StatusCodeBodyCheckFunc 响应状态码出错的检查函数
|
||
StatusCodeBodyCheckFunc func(respBody io.Reader) error
|
||
)
|
||
|
||
// NewDownloader 初始化Downloader
|
||
func NewDownloader(writer io.WriterAt, config *Config, p *config.PanClient, globalSpeedsStat *speeds.Speeds) (der *Downloader) {
|
||
der = &Downloader{
|
||
config: config,
|
||
writer: writer,
|
||
panClient: p,
|
||
globalSpeedsStat: globalSpeedsStat,
|
||
}
|
||
return
|
||
}
|
||
|
||
// SetClient 设置http客户端
|
||
func (der *Downloader) SetFileInfo(f *aliyunpan.FileEntity) {
|
||
der.fileInfo = f
|
||
}
|
||
|
||
func (der *Downloader) SetDriveId(driveId string) {
|
||
der.driveId = driveId
|
||
}
|
||
|
||
// SetClient 设置http客户端
|
||
func (der *Downloader) SetClient(client *requester.HTTPClient) {
|
||
der.client = client
|
||
}
|
||
|
||
// SetLoadBalancerCompareFunc 设置负载均衡检测函数
|
||
func (der *Downloader) SetLoadBalancerCompareFunc(f LoadBalancerCompareFunc) {
|
||
der.loadBalancerCompareFunc = f
|
||
}
|
||
|
||
// SetStatusCodeBodyCheckFunc 设置响应状态码出错的检查函数, 当FirstCheckMethod不为HEAD时才有效
|
||
func (der *Downloader) SetStatusCodeBodyCheckFunc(f StatusCodeBodyCheckFunc) {
|
||
der.statusCodeBodyCheckFunc = f
|
||
}
|
||
|
||
func (der *Downloader) lazyInit() {
|
||
if der.config == nil {
|
||
der.config = NewConfig()
|
||
}
|
||
if der.client == nil {
|
||
der.client = requester.NewHTTPClient()
|
||
der.client.SetTimeout(20 * time.Minute)
|
||
}
|
||
if der.monitor == nil {
|
||
der.monitor = NewMonitor()
|
||
}
|
||
if der.durlCheckFunc == nil {
|
||
der.durlCheckFunc = DefaultDURLCheckFunc
|
||
}
|
||
if der.loadBalancerCompareFunc == nil {
|
||
der.loadBalancerCompareFunc = DefaultLoadBalancerCompareFunc
|
||
}
|
||
}
|
||
|
||
// SelectParallel 获取合适的 parallel
|
||
func (der *Downloader) SelectParallel(single bool, maxParallel int, totalSize int64, instanceRangeList transfer.RangeList) (parallel int) {
|
||
isRange := instanceRangeList != nil && len(instanceRangeList) > 0
|
||
if single { // 单线程下载
|
||
parallel = 1
|
||
} else if isRange {
|
||
parallel = len(instanceRangeList)
|
||
} else {
|
||
parallel = maxParallel // 默认为设置为maxParallel个并发线程数
|
||
if int64(parallel) > totalSize/int64(MinParallelSize) { // 如果文件太小不足切片成maxParallel数量的分片,则计算最接近的分片数量
|
||
parallel = int(totalSize/int64(MinParallelSize)) + 1
|
||
}
|
||
}
|
||
|
||
// 其他情况默认使用单线程下载
|
||
if parallel < 1 {
|
||
parallel = 1
|
||
}
|
||
return
|
||
}
|
||
|
||
// SelectBlockSizeAndInitRangeGen 获取合适的 BlockSize, 和初始化 RangeGen
|
||
func (der *Downloader) SelectBlockSizeAndInitRangeGen(single bool, status *transfer.DownloadStatus, parallel int) (blockSize int64, initErr error) {
|
||
// Range 生成器
|
||
if single { // 单线程
|
||
blockSize = -1
|
||
return
|
||
}
|
||
gen := status.RangeListGen()
|
||
if gen == nil {
|
||
switch der.config.Mode {
|
||
case transfer.RangeGenMode_Default:
|
||
gen = transfer.NewRangeListGenDefault(status.TotalSize(), 0, 0, parallel)
|
||
blockSize = gen.LoadBlockSize()
|
||
case transfer.RangeGenMode_BlockSize:
|
||
b2 := status.TotalSize()/int64(parallel) + 1
|
||
if b2 > der.config.BlockSize { // 选小的BlockSize, 以更高并发
|
||
blockSize = der.config.BlockSize
|
||
} else {
|
||
blockSize = b2
|
||
}
|
||
|
||
gen = transfer.NewRangeListGenBlockSize(status.TotalSize(), 0, blockSize)
|
||
default:
|
||
initErr = transfer.ErrUnknownRangeGenMode
|
||
return
|
||
}
|
||
} else {
|
||
blockSize = gen.LoadBlockSize()
|
||
}
|
||
status.SetRangeListGen(gen)
|
||
return
|
||
}
|
||
|
||
// SelectCacheSize 获取合适的 cacheSize
|
||
func (der *Downloader) SelectCacheSize(confCacheSize int, blockSize int64) (cacheSize int) {
|
||
if blockSize > 0 && int64(confCacheSize) > blockSize {
|
||
// 如果 cache size 过高, 则调低
|
||
cacheSize = int(blockSize)
|
||
} else {
|
||
cacheSize = confCacheSize
|
||
}
|
||
return
|
||
}
|
||
|
||
// DefaultDURLCheckFunc 默认的 DURLCheckFunc
|
||
func DefaultDURLCheckFunc(client *requester.HTTPClient, durl string) (contentLength int64, resp *http.Response, err error) {
|
||
resp, err = client.Req(http.MethodGet, durl, nil, nil)
|
||
if err != nil {
|
||
if resp != nil {
|
||
resp.Body.Close()
|
||
}
|
||
return 0, nil, err
|
||
}
|
||
return resp.ContentLength, resp, nil
|
||
}
|
||
|
||
func (der *Downloader) checkLoadBalancers() *LoadBalancerResponseList {
|
||
var (
|
||
loadBalancerResponses = make([]*LoadBalancerResponse, 0, len(der.loadBalansers)+1)
|
||
handleLoadBalancer = func(req *http.Request) {
|
||
if req == nil {
|
||
return
|
||
}
|
||
|
||
if der.config.TryHTTP {
|
||
req.URL.Scheme = "http"
|
||
}
|
||
|
||
loadBalancer := &LoadBalancerResponse{
|
||
URL: req.URL.String(),
|
||
}
|
||
|
||
loadBalancerResponses = append(loadBalancerResponses, loadBalancer)
|
||
logger.Verbosef("DEBUG: load balance task: URL: %s", loadBalancer.URL)
|
||
}
|
||
)
|
||
|
||
// 加入第一个
|
||
loadBalancerResponses = append(loadBalancerResponses, &LoadBalancerResponse{
|
||
URL: "der.durl",
|
||
})
|
||
|
||
// 负载均衡
|
||
wg := waitgroup.NewWaitGroup(10)
|
||
privTimeout := der.client.Client.Timeout
|
||
der.client.SetTimeout(5 * time.Second)
|
||
for _, loadBalanser := range der.loadBalansers {
|
||
wg.AddDelta()
|
||
go func(loadBalanser string) {
|
||
defer wg.Done()
|
||
|
||
subContentLength, subResp, subErr := der.durlCheckFunc(der.client, loadBalanser)
|
||
if subResp != nil {
|
||
subResp.Body.Close() // 不读Body, 马上关闭连接
|
||
}
|
||
if subErr != nil {
|
||
logger.Verbosef("DEBUG: loadBalanser Error: %s\n", subErr)
|
||
return
|
||
}
|
||
|
||
// 检测状态码
|
||
switch subResp.StatusCode / 100 {
|
||
case 2: // succeed
|
||
case 4, 5: // error
|
||
var err error
|
||
if der.statusCodeBodyCheckFunc != nil {
|
||
err = der.statusCodeBodyCheckFunc(subResp.Body)
|
||
} else {
|
||
err = errors.New(subResp.Status)
|
||
}
|
||
logger.Verbosef("DEBUG: loadBalanser Status Error: %s\n", err)
|
||
return
|
||
}
|
||
|
||
// 检测长度
|
||
if der.fileInfo.FileSize != subContentLength {
|
||
logger.Verbosef("DEBUG: loadBalanser Content-Length not equal to main server\n")
|
||
return
|
||
}
|
||
|
||
//if !der.loadBalancerCompareFunc(der.firstInfo.ToMap(), subResp) {
|
||
// logger.Verbosef("DEBUG: loadBalanser not equal to main server\n")
|
||
// return
|
||
//}
|
||
|
||
handleLoadBalancer(subResp.Request)
|
||
}(loadBalanser)
|
||
}
|
||
wg.Wait()
|
||
der.client.SetTimeout(privTimeout)
|
||
|
||
loadBalancerResponseList := NewLoadBalancerResponseList(loadBalancerResponses)
|
||
return loadBalancerResponseList
|
||
}
|
||
|
||
// Execute 开始任务
|
||
func (der *Downloader) Execute() error {
|
||
der.lazyInit()
|
||
|
||
// zero file, no need to download data
|
||
if der.fileInfo.FileSize == 0 {
|
||
cmdutil.Trigger(der.onFinishEvent)
|
||
return nil
|
||
}
|
||
|
||
var (
|
||
loadBalancerResponseList = der.checkLoadBalancers()
|
||
bii *transfer.DownloadInstanceInfo
|
||
)
|
||
|
||
err := der.initInstanceState(der.config.InstanceStateStorageFormat)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
bii = der.instanceState.Get()
|
||
|
||
var (
|
||
isInstance = bii != nil // 是否存在断点信息
|
||
status *transfer.DownloadStatus
|
||
single = false // 默认开启多线程下载,所以当前single值都为false代表不是单线程下载
|
||
)
|
||
if !isInstance {
|
||
bii = &transfer.DownloadInstanceInfo{}
|
||
}
|
||
|
||
if bii.DownloadStatus != nil {
|
||
// 使用断点信息的状态
|
||
status = bii.DownloadStatus
|
||
} else {
|
||
// 新建状态
|
||
status = transfer.NewDownloadStatus()
|
||
status.SetTotalSize(der.fileInfo.FileSize)
|
||
}
|
||
|
||
// 设置限速
|
||
if der.config.MaxRate > 0 {
|
||
rl := speeds.NewRateLimit(der.config.MaxRate)
|
||
status.SetRateLimit(rl)
|
||
defer rl.Stop()
|
||
}
|
||
|
||
// 计算文件下载的并发线程数,计单个文件下载的并发数
|
||
parallel := der.SelectParallel(single, MaxParallelWorkerCount, status.TotalSize(), bii.Ranges) // 实际的下载并行量
|
||
blockSize, err := der.SelectBlockSizeAndInitRangeGen(single, status, parallel) // 实际的BlockSize
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
cacheSize := der.SelectCacheSize(der.config.CacheSize, blockSize) // 实际下载缓存
|
||
cachepool.SetSyncPoolSize(cacheSize) // 调整pool大小
|
||
|
||
logger.Verbosef("DEBUG: download task CREATED: parallel: %d, cache size: %d\n", parallel, cacheSize)
|
||
|
||
der.monitor.InitMonitorCapacity(parallel)
|
||
|
||
var writer Writer
|
||
// 尝试修剪文件
|
||
if fder, ok := der.writer.(Fder); ok {
|
||
err = prealloc.PreAlloc(fder.Fd(), status.TotalSize())
|
||
if err != nil {
|
||
logger.Verbosef("DEBUG: truncate file error: %s\n", err)
|
||
}
|
||
}
|
||
writer = der.writer
|
||
|
||
// 数据平均分配给各个线程
|
||
isRange := bii.Ranges != nil && len(bii.Ranges) > 0
|
||
if !isRange {
|
||
// 没有使用断点续传
|
||
// 分配线程
|
||
bii.Ranges = make(transfer.RangeList, 0, parallel)
|
||
if single { // 单线程
|
||
bii.Ranges = append(bii.Ranges, &transfer.Range{Begin: 0, End: der.fileInfo.FileSize})
|
||
} else {
|
||
gen := status.RangeListGen()
|
||
for i := 0; i < cap(bii.Ranges); i++ {
|
||
_, r := gen.GenRange()
|
||
if r == nil {
|
||
break
|
||
}
|
||
bii.Ranges = append(bii.Ranges, r)
|
||
}
|
||
}
|
||
}
|
||
|
||
var (
|
||
writeMu = &sync.Mutex{}
|
||
)
|
||
|
||
// 获取下载链接
|
||
var apierr *apierror.ApiError
|
||
durl, apierr := der.panClient.OpenapiPanClient().GetFileDownloadUrl(&aliyunpan.GetFileDownloadUrlParam{
|
||
DriveId: der.driveId,
|
||
FileId: der.fileInfo.FileId,
|
||
})
|
||
time.Sleep(time.Duration(200) * time.Millisecond)
|
||
if apierr != nil {
|
||
logger.Verbosef("ERROR: get download url error: %s\n", der.fileInfo.FileId)
|
||
cmdutil.Trigger(der.onCancelEvent)
|
||
return apierr
|
||
}
|
||
if durl == nil || durl.Url == "" || strings.HasPrefix(durl.Url, aliyunpan.IllegalDownloadUrlPrefix) {
|
||
logger.Verbosef("无法获取有效的下载链接: %+v\n", durl)
|
||
cmdutil.Trigger(der.onCancelEvent)
|
||
der.removeInstanceState() // 移除断点续传文件
|
||
cmdutil.Trigger(der.onFailedEvent)
|
||
return ErrFileDownloadForbidden
|
||
}
|
||
|
||
// 初始化下载worker
|
||
for k, r := range bii.Ranges {
|
||
loadBalancer := loadBalancerResponseList.SequentialGet()
|
||
if loadBalancer == nil {
|
||
continue
|
||
}
|
||
|
||
logger.Verbosef("work id: %d, download url: %v\n", k, durl)
|
||
client := requester.NewHTTPClient()
|
||
client.SetKeepAlive(true)
|
||
client.SetTimeout(10 * time.Minute)
|
||
|
||
realUrl := durl.Url
|
||
worker := NewWorker(k, der.driveId, der.fileInfo.FileId, realUrl, writer, der.globalSpeedsStat)
|
||
worker.SetClient(client)
|
||
worker.SetPanClient(der.panClient)
|
||
worker.SetWriteMutex(writeMu)
|
||
worker.SetTotalSize(der.fileInfo.FileSize)
|
||
|
||
worker.SetAcceptRange("bytes")
|
||
worker.SetRange(r) // 分配Range
|
||
der.monitor.Append(worker)
|
||
}
|
||
|
||
der.monitor.SetStatus(status)
|
||
|
||
// 服务器不支持断点续传, 或者单线程下载, 都不重载worker
|
||
der.monitor.SetReloadWorker(parallel > 1)
|
||
|
||
moniterCtx, moniterCancelFunc := context.WithCancel(context.Background())
|
||
der.monitorCancelFunc = moniterCancelFunc
|
||
|
||
der.monitor.SetInstanceState(der.instanceState)
|
||
|
||
// 开始执行
|
||
der.executeTime = time.Now()
|
||
cmdutil.Trigger(der.onExecuteEvent)
|
||
der.downloadStatusEvent() // 启动执行状态处理事件
|
||
der.monitor.Execute(moniterCtx)
|
||
|
||
// 检查错误
|
||
err = der.monitor.Err()
|
||
if err == nil { // 成功
|
||
cmdutil.Trigger(der.onSuccessEvent)
|
||
der.removeInstanceState() // 移除断点续传文件
|
||
} else {
|
||
if err == ErrNoWokers && der.fileInfo.FileSize == 0 {
|
||
cmdutil.Trigger(der.onSuccessEvent)
|
||
der.removeInstanceState() // 移除断点续传文件
|
||
}
|
||
}
|
||
|
||
// 执行结束
|
||
cmdutil.Trigger(der.onFinishEvent)
|
||
return err
|
||
}
|
||
|
||
// downloadStatusEvent 执行状态处理事件
|
||
func (der *Downloader) downloadStatusEvent() {
|
||
if der.onDownloadStatusEvent == nil {
|
||
return
|
||
}
|
||
|
||
status := der.monitor.Status()
|
||
go func() {
|
||
ticker := time.NewTicker(1 * time.Second)
|
||
defer ticker.Stop()
|
||
for {
|
||
select {
|
||
case <-der.monitor.completed:
|
||
return
|
||
case <-ticker.C:
|
||
time.Sleep(500 * time.Millisecond)
|
||
der.onDownloadStatusEvent(status, der.monitor.RangeWorker)
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
// Pause 暂停
|
||
func (der *Downloader) Pause() {
|
||
if der.monitor == nil {
|
||
return
|
||
}
|
||
cmdutil.Trigger(der.onPauseEvent)
|
||
der.monitor.Pause()
|
||
}
|
||
|
||
// Resume 恢复
|
||
func (der *Downloader) Resume() {
|
||
if der.monitor == nil {
|
||
return
|
||
}
|
||
cmdutil.Trigger(der.onResumeEvent)
|
||
der.monitor.Resume()
|
||
}
|
||
|
||
// Cancel 取消
|
||
func (der *Downloader) Cancel() {
|
||
if der.monitor == nil {
|
||
return
|
||
}
|
||
cmdutil.Trigger(der.onCancelEvent)
|
||
cmdutil.Trigger(der.monitorCancelFunc)
|
||
}
|
||
|
||
// Failed 失败
|
||
func (der *Downloader) Failed() {
|
||
if der.monitor == nil {
|
||
return
|
||
}
|
||
cmdutil.Trigger(der.onFailedEvent)
|
||
cmdutil.Trigger(der.monitorCancelFunc)
|
||
}
|
||
|
||
// OnExecute 设置开始下载事件
|
||
func (der *Downloader) OnExecute(onExecuteEvent requester.Event) {
|
||
der.onExecuteEvent = onExecuteEvent
|
||
}
|
||
|
||
// OnSuccess 设置成功下载事件
|
||
func (der *Downloader) OnSuccess(onSuccessEvent requester.Event) {
|
||
der.onSuccessEvent = onSuccessEvent
|
||
}
|
||
|
||
// OnFailed 设置失败事件
|
||
func (der *Downloader) OnFailed(onFailedEvent requester.Event) {
|
||
der.onFailedEvent = onFailedEvent
|
||
}
|
||
|
||
// OnFinish 设置结束下载事件
|
||
func (der *Downloader) OnFinish(onFinishEvent requester.Event) {
|
||
der.onFinishEvent = onFinishEvent
|
||
}
|
||
|
||
// OnPause 设置暂停下载事件
|
||
func (der *Downloader) OnPause(onPauseEvent requester.Event) {
|
||
der.onPauseEvent = onPauseEvent
|
||
}
|
||
|
||
// OnResume 设置恢复下载事件
|
||
func (der *Downloader) OnResume(onResumeEvent requester.Event) {
|
||
der.onResumeEvent = onResumeEvent
|
||
}
|
||
|
||
// OnCancel 设置取消下载事件
|
||
func (der *Downloader) OnCancel(onCancelEvent requester.Event) {
|
||
der.onCancelEvent = onCancelEvent
|
||
}
|
||
|
||
// OnDownloadStatusEvent 设置状态处理函数
|
||
func (der *Downloader) OnDownloadStatusEvent(f DownloadStatusFunc) {
|
||
der.onDownloadStatusEvent = f
|
||
}
|