mirror of
https://github.com/tickstep/aliyunpan.git
synced 2025-01-23 22:42:15 +08:00
498 lines
12 KiB
Go
498 lines
12 KiB
Go
// Copyright (c) 2020 tickstep.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
package downloader
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/tickstep/aliyunpan-api/aliyunpan"
|
|
"github.com/tickstep/aliyunpan-api/aliyunpan/apierror"
|
|
"github.com/tickstep/aliyunpan/internal/config"
|
|
"github.com/tickstep/aliyunpan/library/requester/transfer"
|
|
"github.com/tickstep/library-go/cachepool"
|
|
"github.com/tickstep/library-go/logger"
|
|
"github.com/tickstep/library-go/requester"
|
|
"github.com/tickstep/library-go/requester/rio/speeds"
|
|
"io"
|
|
"net/http"
|
|
"sync"
|
|
)
|
|
|
|
type (
|
|
//Worker 工作单元
|
|
Worker struct {
|
|
totalSize int64 // 整个文件的大小, worker请求range时会获取尝试获取该值, 如果不匹配, 则返回错误
|
|
wrange *transfer.Range
|
|
speedsStat *speeds.Speeds
|
|
globalSpeedsStat *speeds.Speeds // 全局速度统计
|
|
id int // work id
|
|
fileId string // 文件ID
|
|
driveId string
|
|
url string // 下载地址
|
|
acceptRanges string
|
|
panClient *config.PanClient
|
|
client *requester.HTTPClient
|
|
writerAt io.WriterAt
|
|
writeMu *sync.Mutex
|
|
execMu sync.Mutex
|
|
|
|
pauseChan chan struct{}
|
|
workerCancelFunc context.CancelFunc
|
|
resetFunc context.CancelFunc
|
|
readRespBodyCancelFunc func()
|
|
err error // 错误信息
|
|
status WorkerStatus
|
|
downloadStatus *transfer.DownloadStatus // 总的下载状态
|
|
}
|
|
|
|
// WorkerList worker列表
|
|
WorkerList []*Worker
|
|
)
|
|
|
|
// Duplicate 构造新的列表
|
|
func (wl WorkerList) Duplicate() WorkerList {
|
|
n := make(WorkerList, len(wl))
|
|
copy(n, wl)
|
|
return n
|
|
}
|
|
|
|
// NewWorker 初始化Worker
|
|
func NewWorker(id int, driveId string, fileId, durl string, writerAt io.WriterAt, globalSpeedsStat *speeds.Speeds) *Worker {
|
|
return &Worker{
|
|
id: id,
|
|
url: durl,
|
|
writerAt: writerAt,
|
|
fileId: fileId,
|
|
driveId: driveId,
|
|
globalSpeedsStat: globalSpeedsStat,
|
|
}
|
|
}
|
|
|
|
// ID 返回worker ID
|
|
func (wer *Worker) ID() int {
|
|
return wer.id
|
|
}
|
|
|
|
func (wer *Worker) lazyInit() {
|
|
if wer.client == nil {
|
|
wer.client = requester.NewHTTPClient()
|
|
}
|
|
if wer.pauseChan == nil {
|
|
wer.pauseChan = make(chan struct{})
|
|
}
|
|
if wer.wrange == nil {
|
|
wer.wrange = &transfer.Range{}
|
|
}
|
|
if wer.wrange.LoadBegin() == 0 && wer.wrange.LoadEnd() == 0 {
|
|
// 取消多线程下载
|
|
wer.acceptRanges = ""
|
|
wer.wrange.StoreEnd(-2)
|
|
}
|
|
if wer.speedsStat == nil {
|
|
wer.speedsStat = &speeds.Speeds{}
|
|
}
|
|
}
|
|
|
|
// SetTotalSize 设置整个文件的大小, worker请求range时会获取尝试获取该值, 如果不匹配, 则返回错误
|
|
func (wer *Worker) SetTotalSize(size int64) {
|
|
wer.totalSize = size
|
|
}
|
|
|
|
// SetClient 设置http客户端
|
|
func (wer *Worker) SetClient(c *requester.HTTPClient) {
|
|
wer.client = c
|
|
}
|
|
|
|
func (wer *Worker) SetPanClient(p *config.PanClient) {
|
|
wer.panClient = p
|
|
}
|
|
|
|
// SetAcceptRange 设置AcceptRange
|
|
func (wer *Worker) SetAcceptRange(acceptRanges string) {
|
|
wer.acceptRanges = acceptRanges
|
|
}
|
|
|
|
// SetRange 设置请求范围
|
|
func (wer *Worker) SetRange(r *transfer.Range) {
|
|
if wer.wrange == nil {
|
|
wer.wrange = r
|
|
return
|
|
}
|
|
wer.wrange.StoreBegin(r.LoadBegin())
|
|
wer.wrange.StoreEnd(r.LoadEnd())
|
|
}
|
|
|
|
// SetWriteMutex 设置数据写锁
|
|
func (wer *Worker) SetWriteMutex(mu *sync.Mutex) {
|
|
wer.writeMu = mu
|
|
}
|
|
|
|
// SetDownloadStatus 增加其他需要统计的数据
|
|
func (wer *Worker) SetDownloadStatus(downloadStatus *transfer.DownloadStatus) {
|
|
wer.downloadStatus = downloadStatus
|
|
}
|
|
|
|
// GetStatus 返回下载状态
|
|
func (wer *Worker) GetStatus() WorkerStatuser {
|
|
// 空接口与空指针不等价
|
|
return &wer.status
|
|
}
|
|
|
|
// GetRange 返回worker范围
|
|
func (wer *Worker) GetRange() *transfer.Range {
|
|
return wer.wrange
|
|
}
|
|
|
|
// GetSpeedsPerSecond 获取每秒的速度
|
|
func (wer *Worker) GetSpeedsPerSecond() int64 {
|
|
return wer.speedsStat.GetSpeeds()
|
|
}
|
|
|
|
// Pause 暂停下载
|
|
func (wer *Worker) Pause() {
|
|
wer.lazyInit()
|
|
if wer.acceptRanges == "" {
|
|
logger.Verbosef("WARNING: worker unsupport pause")
|
|
return
|
|
}
|
|
|
|
if wer.status.statusCode == StatusCodePaused {
|
|
return
|
|
}
|
|
wer.pauseChan <- struct{}{}
|
|
wer.status.statusCode = StatusCodePaused
|
|
}
|
|
|
|
// Resume 恢复下载
|
|
func (wer *Worker) Resume() {
|
|
if wer.status.statusCode != StatusCodePaused {
|
|
return
|
|
}
|
|
go wer.Execute()
|
|
}
|
|
|
|
// Cancel 取消下载
|
|
func (wer *Worker) Cancel() error {
|
|
if wer.workerCancelFunc == nil {
|
|
return errors.New("cancelFunc not set")
|
|
}
|
|
wer.workerCancelFunc()
|
|
if wer.readRespBodyCancelFunc != nil {
|
|
wer.readRespBodyCancelFunc()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Reset 重设连接
|
|
func (wer *Worker) Reset() {
|
|
if wer.resetFunc == nil {
|
|
logger.Verbosef("DEBUG: worker: resetFunc not set")
|
|
return
|
|
}
|
|
wer.resetFunc()
|
|
if wer.readRespBodyCancelFunc != nil {
|
|
wer.readRespBodyCancelFunc()
|
|
}
|
|
wer.ClearStatus()
|
|
go wer.Execute()
|
|
}
|
|
|
|
// RefreshDownloadUrl 重新刷新下载链接
|
|
func (wer *Worker) RefreshDownloadUrl() {
|
|
var apierr *apierror.ApiError
|
|
|
|
durl, apierr := wer.panClient.OpenapiPanClient().GetFileDownloadUrl(&aliyunpan.GetFileDownloadUrlParam{DriveId: wer.driveId, FileId: wer.fileId})
|
|
if apierr != nil {
|
|
wer.status.statusCode = StatusCodeTooManyConnections
|
|
return
|
|
}
|
|
wer.url = durl.Url
|
|
}
|
|
|
|
// Canceled 是否已经取消
|
|
func (wer *Worker) Canceled() bool {
|
|
return wer.status.statusCode == StatusCodeCanceled
|
|
}
|
|
|
|
// Completed 是否已经完成
|
|
func (wer *Worker) Completed() bool {
|
|
switch wer.status.statusCode {
|
|
case StatusCodeSuccessed, StatusCodeCanceled:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// Failed 是否失败
|
|
func (wer *Worker) Failed() bool {
|
|
switch wer.status.statusCode {
|
|
case StatusCodeFailed, StatusCodeInternalError, StatusCodeTooManyConnections, StatusCodeNetError:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// ClearStatus 清空状态
|
|
func (wer *Worker) ClearStatus() {
|
|
wer.status.statusCode = StatusCodeInit
|
|
}
|
|
|
|
// Err 返回worker错误
|
|
func (wer *Worker) Err() error {
|
|
return wer.err
|
|
}
|
|
|
|
// Execute 执行任务
|
|
func (wer *Worker) Execute() {
|
|
wer.lazyInit()
|
|
|
|
wer.execMu.Lock()
|
|
defer wer.execMu.Unlock()
|
|
|
|
wer.status.statusCode = StatusCodeInit
|
|
single := wer.acceptRanges == ""
|
|
|
|
// 如果已暂停, 退出
|
|
if wer.status.statusCode == StatusCodePaused {
|
|
return
|
|
}
|
|
|
|
if !single {
|
|
// 已完成
|
|
if rlen := wer.wrange.Len(); rlen <= 0 {
|
|
if rlen < 0 {
|
|
logger.Verbosef("DEBUG: RangeLen is negative at begin: %v, %d\n", wer.wrange, wer.wrange.Len())
|
|
}
|
|
wer.status.statusCode = StatusCodeSuccessed
|
|
return
|
|
}
|
|
}
|
|
|
|
// zero size file
|
|
if wer.totalSize == 0 {
|
|
wer.status.statusCode = StatusCodeSuccessed
|
|
return
|
|
}
|
|
|
|
workerCancelCtx, workerCancelFunc := context.WithCancel(context.Background())
|
|
wer.workerCancelFunc = workerCancelFunc
|
|
resetCtx, resetFunc := context.WithCancel(context.Background())
|
|
wer.resetFunc = resetFunc
|
|
|
|
wer.status.statusCode = StatusCodePending
|
|
|
|
var resp *http.Response
|
|
|
|
apierr := wer.panClient.OpenapiPanClient().DownloadFileData(wer.url, aliyunpan.FileDownloadRange{
|
|
Offset: wer.wrange.Begin,
|
|
End: wer.wrange.End - 1,
|
|
}, func(httpMethod, fullUrl string, headers map[string]string) (*http.Response, error) {
|
|
resp, wer.err = wer.client.Req(httpMethod, fullUrl, nil, headers)
|
|
if wer.err != nil {
|
|
return nil, wer.err
|
|
}
|
|
return resp, wer.err
|
|
})
|
|
|
|
if resp != nil {
|
|
defer func() {
|
|
resp.Body.Close()
|
|
}()
|
|
wer.readRespBodyCancelFunc = func() {
|
|
resp.Body.Close()
|
|
}
|
|
}
|
|
if wer.err != nil || apierr != nil {
|
|
wer.status.statusCode = StatusCodeNetError
|
|
return
|
|
}
|
|
|
|
// 判断响应状态
|
|
switch resp.StatusCode {
|
|
case 200, 206:
|
|
// do nothing, continue
|
|
wer.status.statusCode = StatusCodeDownloading
|
|
break
|
|
case 416: //Requested Range Not Satisfiable
|
|
fallthrough
|
|
case 403: // Forbidden
|
|
fallthrough
|
|
case 406: // Not Acceptable
|
|
wer.status.statusCode = StatusCodeNetError
|
|
wer.err = errors.New(resp.Status)
|
|
return
|
|
case 404:
|
|
wer.status.statusCode = StatusCodeDownloadUrlExpired
|
|
wer.err = errors.New(resp.Status)
|
|
return
|
|
case 429, 509: // Too Many Requests
|
|
wer.status.SetStatusCode(StatusCodeTooManyConnections)
|
|
wer.err = errors.New(resp.Status)
|
|
return
|
|
default:
|
|
wer.status.statusCode = StatusCodeNetError
|
|
wer.err = fmt.Errorf("unexpected http status code, %d, %s", resp.StatusCode, resp.Status)
|
|
return
|
|
}
|
|
|
|
var (
|
|
contentLength = resp.ContentLength
|
|
rangeLength = wer.wrange.Len()
|
|
)
|
|
|
|
if !single {
|
|
// 检查请求长度
|
|
if contentLength != rangeLength {
|
|
wer.status.statusCode = StatusCodeNetError
|
|
wer.err = fmt.Errorf("Content-Length is unexpected: %d, need %d", contentLength, rangeLength)
|
|
return
|
|
}
|
|
// 检查总大小
|
|
if wer.totalSize > 0 {
|
|
total := ParseContentRange(resp.Header.Get("Content-Range"))
|
|
if total > 0 {
|
|
if total != wer.totalSize {
|
|
wer.status.statusCode = StatusCodeInternalError // 这里设置为内部错误, 强制停止下载
|
|
wer.err = fmt.Errorf("Content-Range total length is unexpected: %d, need %d", total, wer.totalSize)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
var (
|
|
buf = cachepool.SyncPool.Get().([]byte)
|
|
n, nn int
|
|
n64, nn64 int64
|
|
)
|
|
defer cachepool.SyncPool.Put(buf)
|
|
|
|
for {
|
|
select {
|
|
case <-workerCancelCtx.Done(): //取消
|
|
wer.status.statusCode = StatusCodeCanceled
|
|
return
|
|
case <-resetCtx.Done(): //重设连接
|
|
wer.status.statusCode = StatusCodeReseted
|
|
return
|
|
case <-wer.pauseChan: //暂停
|
|
return
|
|
default:
|
|
wer.status.statusCode = StatusCodeDownloading
|
|
|
|
// 初始化数据
|
|
var readErr error
|
|
n = 0
|
|
|
|
// 读取数据
|
|
for n < len(buf) && readErr == nil && (single || wer.wrange.Len() > 0) {
|
|
nn, readErr = resp.Body.Read(buf[n:])
|
|
nn64 = int64(nn)
|
|
|
|
// 更新速度统计
|
|
if wer.downloadStatus != nil {
|
|
wer.downloadStatus.AddSpeedsDownloaded(nn64) // 限速在这里阻塞
|
|
}
|
|
wer.speedsStat.Add(nn64)
|
|
if wer.globalSpeedsStat != nil {
|
|
wer.globalSpeedsStat.Add(nn64)
|
|
}
|
|
n += nn
|
|
}
|
|
|
|
if n > 0 && readErr == io.EOF {
|
|
readErr = io.ErrUnexpectedEOF
|
|
}
|
|
|
|
n64 = int64(n)
|
|
|
|
// 非单线程模式下
|
|
if !single {
|
|
rangeLength = wer.wrange.Len()
|
|
|
|
// 已完成
|
|
if rangeLength <= 0 {
|
|
wer.status.statusCode = StatusCodeCanceled
|
|
wer.err = errors.New("worker already complete")
|
|
return
|
|
}
|
|
|
|
if n64 > rangeLength {
|
|
// 数据大小不正常
|
|
n64 = rangeLength
|
|
n = int(rangeLength)
|
|
readErr = io.EOF
|
|
}
|
|
}
|
|
|
|
// 写入数据
|
|
if wer.writerAt != nil {
|
|
wer.status.statusCode = StatusCodeWaitToWrite
|
|
if wer.writeMu != nil {
|
|
wer.writeMu.Lock() // 加锁, 减轻硬盘的压力
|
|
}
|
|
_, wer.err = wer.writerAt.WriteAt(buf[:n], wer.wrange.Begin) // 写入数据
|
|
if wer.err != nil {
|
|
if wer.writeMu != nil {
|
|
wer.writeMu.Unlock() //解锁
|
|
}
|
|
wer.status.statusCode = StatusCodeInternalError
|
|
return
|
|
}
|
|
|
|
if wer.writeMu != nil {
|
|
wer.writeMu.Unlock() //解锁
|
|
}
|
|
wer.status.statusCode = StatusCodeDownloading
|
|
}
|
|
|
|
// 更新下载统计数据
|
|
wer.wrange.AddBegin(n64)
|
|
if wer.downloadStatus != nil {
|
|
wer.downloadStatus.AddDownloaded(n64)
|
|
if single {
|
|
wer.downloadStatus.AddTotalSize(n64)
|
|
}
|
|
}
|
|
|
|
if readErr != nil {
|
|
rlen := wer.wrange.Len()
|
|
switch {
|
|
case single && readErr == io.ErrUnexpectedEOF:
|
|
// 单线程判断下载成功
|
|
fallthrough
|
|
case readErr == io.EOF:
|
|
fallthrough
|
|
case rlen <= 0:
|
|
// 下载完成
|
|
// 小于0可能是因为 worker 被 duplicate
|
|
wer.status.statusCode = StatusCodeSuccessed
|
|
if rlen < 0 {
|
|
logger.Verbosef("DEBUG: RangeLen is negative at end: %v, %d\n", wer.wrange, wer.wrange.Len())
|
|
}
|
|
return
|
|
default:
|
|
// 其他错误, 返回
|
|
wer.status.statusCode = StatusCodeFailed
|
|
wer.err = readErr
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|