aliyunpan/internal/file/downloader/worker.go
2024-08-09 10:17:02 +08:00

539 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright (c) 2020 tickstep.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package downloader
import (
"context"
"errors"
"fmt"
"github.com/tickstep/aliyunpan-api/aliyunpan"
"github.com/tickstep/aliyunpan-api/aliyunpan/apierror"
"github.com/tickstep/aliyunpan/internal/config"
"github.com/tickstep/aliyunpan/library/requester/transfer"
"github.com/tickstep/library-go/cachepool"
"github.com/tickstep/library-go/logger"
"github.com/tickstep/library-go/requester"
"github.com/tickstep/library-go/requester/rio/speeds"
"io"
"io/ioutil"
"net/http"
"strings"
"sync"
"time"
)
// Worker downloads one byte range of a file (or the whole body when ranged
// requests are disabled) and writes the data through an io.WriterAt.
type Worker struct {
	totalSize              int64 // size of the whole file; Execute compares the Content-Range total against it and fails on mismatch
	wrange                 *transfer.Range
	speedsStat             *speeds.Speeds // per-worker speed statistics
	globalSpeedsStat       *speeds.Speeds // global speed statistics (may be nil)
	id                     int            // worker id
	fileId                 string         // file ID
	driveId                string
	url                    string // download URL
	acceptRanges           string // an empty value means ranged (multi-part) downloading is disabled
	panClient              *config.PanClient
	client                 *requester.HTTPClient
	writerAt               io.WriterAt
	writeMu                *sync.Mutex // optional lock held around WriteAt to ease disk pressure
	execMu                 sync.Mutex  // serialises Execute calls
	pauseChan              chan struct{}
	workerCancelFunc       context.CancelFunc // created by Execute, invoked by Cancel
	resetFunc              context.CancelFunc // created by Execute, invoked by Reset
	readRespBodyCancelFunc func()             // closes the current response body
	err                    error              // last error
	status                 WorkerStatus
	downloadStatus         *transfer.DownloadStatus // overall download status shared with the caller
}

// WorkerList is a list of workers.
type WorkerList []*Worker
// Duplicate returns a new list holding the same *Worker pointers as wl.
// Only the slice is copied (shallow copy); the workers themselves are shared.
// The result is never nil, even when wl is empty.
func (wl WorkerList) Duplicate() WorkerList {
	return append(make(WorkerList, 0, len(wl)), wl...)
}
// NewWorker creates a Worker that downloads fileId of driveId from durl and
// writes the data into writerAt. globalSpeedsStat may be nil. Any other
// dependency is supplied through the Set* methods or filled in by lazyInit.
func NewWorker(id int, driveId string, fileId, durl string, writerAt io.WriterAt, globalSpeedsStat *speeds.Speeds) *Worker {
	wer := new(Worker)
	wer.id = id
	wer.driveId = driveId
	wer.fileId = fileId
	wer.url = durl
	wer.writerAt = writerAt
	wer.globalSpeedsStat = globalSpeedsStat
	return wer
}
// ID returns the identifier the worker was created with.
func (wer *Worker) ID() int { return wer.id }
// lazyInit fills in every dependency that was not injected: the speed
// counter, the HTTP client, the pause channel and the byte range.
//
// A range whose begin and end are both 0 means "no range". In that case
// ranged (multi-part) downloading is turned off by clearing acceptRanges,
// and End is set to the sentinel value -2.
func (wer *Worker) lazyInit() {
	if wer.speedsStat == nil {
		wer.speedsStat = new(speeds.Speeds)
	}
	if wer.client == nil {
		wer.client = requester.NewHTTPClient()
	}
	if wer.pauseChan == nil {
		wer.pauseChan = make(chan struct{})
	}
	if wer.wrange == nil {
		wer.wrange = new(transfer.Range)
	}

	noRange := wer.wrange.LoadBegin() == 0 && wer.wrange.LoadEnd() == 0
	if noRange {
		wer.acceptRanges = ""
		wer.wrange.StoreEnd(-2)
	}
}
// SetTotalSize records the size of the whole file. In ranged mode Execute
// compares the Content-Range total against it and fails on a mismatch.
func (wer *Worker) SetTotalSize(size int64) { wer.totalSize = size }

// SetClient sets the HTTP client used for requests.
func (wer *Worker) SetClient(c *requester.HTTPClient) { wer.client = c }

// SetPanClient sets the pan client used to fetch data and refresh URLs.
func (wer *Worker) SetPanClient(p *config.PanClient) { wer.panClient = p }

// SetAcceptRange sets the Accept-Ranges value; an empty string disables
// ranged downloading.
func (wer *Worker) SetAcceptRange(acceptRanges string) { wer.acceptRanges = acceptRanges }

// SetRange sets the byte range to download. The first call adopts r itself.
// Later calls copy r's begin and end into the range the worker already
// holds, so anyone keeping a reference to that range sees the update.
func (wer *Worker) SetRange(r *transfer.Range) {
	if wer.wrange != nil {
		wer.wrange.StoreBegin(r.LoadBegin())
		wer.wrange.StoreEnd(r.LoadEnd())
		return
	}
	wer.wrange = r
}

// SetWriteMutex sets the lock held around each WriteAt call.
func (wer *Worker) SetWriteMutex(mu *sync.Mutex) { wer.writeMu = mu }

// SetDownloadStatus sets the aggregate status that receives this worker's
// progress and speed figures.
func (wer *Worker) SetDownloadStatus(downloadStatus *transfer.DownloadStatus) {
	wer.downloadStatus = downloadStatus
}

// SetUrl replaces the download URL.
func (wer *Worker) SetUrl(newUrl string) { wer.url = newUrl }
// GetStatus returns the worker's status.
//
// It returns a pointer to the worker's status field, which is never nil.
// (An interface holding a nil pointer is not itself a nil interface, so
// handing out a possibly-nil pointer would be a trap.)
func (wer *Worker) GetStatus() WorkerStatuser {
	return &wer.status
}

// GetRange returns the worker's byte range. It is nil until SetRange or
// lazyInit assigns one.
func (wer *Worker) GetRange() *transfer.Range {
	return wer.wrange
}

// GetSpeedsPerSecond returns the worker's current download speed.
//
// The per-worker counter is only created by lazyInit (on the first Execute
// or Pause). A caller polling before that gets 0 instead of calling
// GetSpeeds on a nil *speeds.Speeds.
func (wer *Worker) GetSpeedsPerSecond() int64 {
	if wer.speedsStat == nil {
		return 0
	}
	return wer.speedsStat.GetSpeeds()
}
// Pause asks a running Execute to stop and marks the worker as paused.
//
// Pausing is only supported in ranged mode (non-empty acceptRanges). A
// second Pause on an already-paused worker is a no-op.
//
// NOTE(review): pauseChan is created unbuffered by lazyInit. The send below
// therefore blocks until Execute's download loop receives it, and blocks
// indefinitely if no Execute is running.
func (wer *Worker) Pause() {
	wer.lazyInit()
	switch {
	case wer.acceptRanges == "":
		logger.Verbosef("WARNING: worker unsupport pause")
	case wer.status.statusCode == StatusCodePaused:
		// already paused
	default:
		wer.pauseChan <- struct{}{}
		wer.status.statusCode = StatusCodePaused
	}
}
// Resume restarts a paused worker by running Execute in a new goroutine.
// It does nothing unless the worker is currently paused.
func (wer *Worker) Resume() {
	if wer.status.statusCode == StatusCodePaused {
		go wer.Execute()
	}
}
// Cancel stops the current download. It fires the cancel function created
// by Execute and closes the in-flight response body, if any, so a blocked
// read returns. It returns an error when Execute has not set a cancel
// function yet.
func (wer *Worker) Cancel() error {
	cancel := wer.workerCancelFunc
	if cancel == nil {
		return errors.New("cancelFunc not set")
	}
	cancel()
	if closeBody := wer.readRespBodyCancelFunc; closeBody != nil {
		closeBody()
	}
	return nil
}
// Reset drops the current connection and starts a fresh attempt.
//
// It fires the reset function created by Execute, closes the in-flight
// response body, clears the status and launches Execute in a new goroutine.
// That Execute waits on execMu until the previous one has returned. If
// Execute has not run yet, Reset only logs and returns.
func (wer *Worker) Reset() {
	reset := wer.resetFunc
	if reset == nil {
		logger.Verbosef("DEBUG: worker: resetFunc not set")
		return
	}
	reset()
	if closeBody := wer.readRespBodyCancelFunc; closeBody != nil {
		closeBody()
	}
	wer.ClearStatus()
	go wer.Execute()
}
// RefreshDownloadUrl requests a new download URL for the worker's file and
// stores it. If the request fails, the URL is left unchanged and the status
// becomes StatusCodeTooManyConnections.
func (wer *Worker) RefreshDownloadUrl() {
	var apierr *apierror.ApiError
	param := &aliyunpan.GetFileDownloadUrlParam{DriveId: wer.driveId, FileId: wer.fileId}
	durl, apierr := wer.panClient.OpenapiPanClient().GetFileDownloadUrl(param)
	if apierr == nil {
		wer.url = durl.Url
		return
	}
	wer.status.statusCode = StatusCodeTooManyConnections
}
// Canceled reports whether the worker's status is StatusCodeCanceled.
func (wer *Worker) Canceled() bool {
	return wer.status.statusCode == StatusCodeCanceled
}

// Completed reports whether the worker is done for good: it either
// succeeded or was canceled.
func (wer *Worker) Completed() bool {
	code := wer.status.statusCode
	return code == StatusCodeSuccessed || code == StatusCodeCanceled
}

// Failed reports whether the worker stopped with a failure status:
// failed, internal error, too many connections or network error.
func (wer *Worker) Failed() bool {
	code := wer.status.statusCode
	return code == StatusCodeFailed ||
		code == StatusCodeInternalError ||
		code == StatusCodeTooManyConnections ||
		code == StatusCodeNetError
}

// ClearStatus puts the worker back into the initial status.
func (wer *Worker) ClearStatus() {
	wer.status.statusCode = StatusCodeInit
}

// Err returns the last error recorded by the worker.
func (wer *Worker) Err() error {
	return wer.err
}
// Execute performs one download attempt for this worker.
//
// In ranged mode (acceptRanges != "") it fetches wer.wrange and writes the
// data at the range offset through writerAt. In single mode
// (acceptRanges == "") it streams the whole response body. execMu is held
// for the entire attempt, so overlapping calls (e.g. from Resume or Reset)
// run one after another.
//
// The outcome is reported only through wer.status and wer.err:
//   - StatusCodeSuccessed: nothing was left to fetch, or the body ended / the
//     range was filled.
//   - StatusCodeCanceled / StatusCodeReseted: Cancel / Reset fired.
//   - StatusCodeDownloadUrlExpired: the URL is, or the server says it is,
//     expired.
//   - StatusCodeDownloadUrlExceedMaxConcurrency: the server rejected the
//     request for exceeding its concurrency limit. Execute backs off for up
//     to 10 seconds before returning.
//   - StatusCodeTooManyConnections, StatusCodeNetError, StatusCodeFailed,
//     StatusCodeInternalError: the attempt failed and wer.err describes why.
//
// A Pause signal makes Execute return without touching the status; Pause
// records StatusCodePaused itself.
func (wer *Worker) Execute() {
	wer.lazyInit()

	wer.execMu.Lock()
	defer wer.execMu.Unlock()

	// Resume calls Execute while the status is still StatusCodePaused, so the
	// status is reset unconditionally here. No "already paused" check can
	// follow this assignment; it would never fire.
	wer.status.statusCode = StatusCodeInit
	single := wer.acceptRanges == ""

	if !single {
		// The range is already complete.
		if rlen := wer.wrange.Len(); rlen <= 0 {
			if rlen < 0 {
				logger.Verbosef("DEBUG: RangeLen is negative at begin: %v, %d\n", wer.wrange, wer.wrange.Len())
			}
			wer.status.statusCode = StatusCodeSuccessed
			return
		}
	}

	// Zero-size file: nothing to download.
	if wer.totalSize == 0 {
		wer.status.statusCode = StatusCodeSuccessed
		return
	}

	// Cancel and Reset signal this attempt through these contexts.
	workerCancelCtx, workerCancelFunc := context.WithCancel(context.Background())
	wer.workerCancelFunc = workerCancelFunc
	resetCtx, resetFunc := context.WithCancel(context.Background())
	wer.resetFunc = resetFunc

	wer.status.statusCode = StatusCodePending

	// Don't request a URL that is already known to be expired.
	if IsUrlExpired(wer.url) {
		logger.Verbosef("download url expired, renew url and reset worker: %d\n", wer.ID())
		wer.status.statusCode = StatusCodeDownloadUrlExpired
		wer.err = errors.New("403")
		return
	}

	// Issue the request. The callback performs the HTTP round trip so that the
	// response, and its body, stay available to this function.
	var resp *http.Response
	apierr := wer.panClient.OpenapiPanClient().DownloadFileData(wer.url, aliyunpan.FileDownloadRange{
		Offset: wer.wrange.LoadBegin(),
		// NOTE(review): assumes wrange.End is exclusive and the API expects an
		// inclusive end offset — TODO confirm.
		End: wer.wrange.LoadEnd() - 1,
	}, func(httpMethod, fullUrl string, headers map[string]string) (*http.Response, error) {
		resp, wer.err = wer.client.Req(httpMethod, fullUrl, nil, headers)
		if wer.err != nil {
			return nil, wer.err
		}
		return resp, nil
	})
	if resp != nil {
		defer resp.Body.Close()
		// Lets Cancel/Reset unblock a pending Body.Read by closing the body.
		wer.readRespBodyCancelFunc = func() {
			resp.Body.Close()
		}
	}
	if wer.err != nil || apierr != nil {
		wer.status.statusCode = StatusCodeNetError
		if wer.err == nil {
			// Keep Err() non-nil alongside a failure status.
			wer.err = fmt.Errorf("download file data: %v", apierr)
		}
		return
	}

	// Check the response status.
	switch resp.StatusCode {
	case 200, 206:
		wer.status.statusCode = StatusCodeDownloading
	case 403, 416: // 403 Forbidden (also returned for expired links); 416 Requested Range Not Satisfiable
		// Generic failure unless the body identifies a more specific cause.
		// Without this default, an unrecognised body (or a failed body read)
		// would leave the status at StatusCodePending with a nil error, and the
		// worker would look neither failed nor completed.
		wer.status.statusCode = StatusCodeNetError
		wer.err = errors.New(resp.Status)

		respBody, e := ioutil.ReadAll(resp.Body)
		if e != nil {
			return
		}
		respBodyStr := string(respBody)
		switch {
		case strings.Contains(respBodyStr, "Request has expired"): // the link has expired
			logger.Verboseln("download url return 403 error and expired url response")
			wer.status.statusCode = StatusCodeDownloadUrlExpired
		case strings.Contains(respBodyStr, "ExceedMaxConcurrency"): // throttled
			// Regular apps: the concurrency limit for chunked downloads is 3.
			// A user of the app may download 3 chunks of one file at once, or
			// 1 chunk each of 3 files. Exceeding the limit makes the API
			// respond with HTTP 403 and a body like:
			// <?xml version="1.0" encoding="UTF-8"?>
			//<Error>
			// <Code>RequestDeniedByCallback</Code>
			// <Message>Callback deny this request reason: ExceedMaxConcurrency</Message>
			// <RequestId>66B5720887CECD32333EB8C1</RequestId>
			// <HostId>cn-beijing-data.aliyundrive.net</HostId>
			// <EC>0007-00000209</EC>
			// <RecommendDoc>https://api.aliyun.com/troubleshoot?q=0007-00000209</RecommendDoc>
			//</Error>
			logger.Verboseln("download url return 403 error and exceed max concurrency response")
			wer.status.statusCode = StatusCodeDownloadUrlExceedMaxConcurrency
			// Back off before the caller retries. Wake up early on Cancel/Reset
			// instead of holding execMu for the full delay.
			select {
			case <-workerCancelCtx.Done():
				wer.status.statusCode = StatusCodeCanceled
			case <-resetCtx.Done():
				wer.status.statusCode = StatusCodeReseted
			case <-time.After(10 * time.Second):
			}
		}
		return
	case 406: // Not Acceptable
		wer.status.statusCode = StatusCodeNetError
		wer.err = errors.New(resp.Status)
		return
	case 404:
		logger.Verboseln("request download url 404 error")
		fallthrough
	case 429, 509: // Too Many Requests
		wer.status.SetStatusCode(StatusCodeTooManyConnections)
		wer.err = errors.New(resp.Status)
		return
	default:
		wer.status.statusCode = StatusCodeNetError
		wer.err = fmt.Errorf("unexpected http status code, %d, %s", resp.StatusCode, resp.Status)
		return
	}

	var (
		contentLength = resp.ContentLength
		rangeLength   = wer.wrange.Len()
	)

	if !single {
		// The response must cover exactly the requested range.
		if contentLength != rangeLength {
			wer.status.statusCode = StatusCodeNetError
			wer.err = fmt.Errorf("Content-Length is unexpected: %d, need %d", contentLength, rangeLength)
			return
		}
		// The response must also describe the file size we expect.
		if wer.totalSize > 0 {
			total := ParseContentRange(resp.Header.Get("Content-Range"))
			if total > 0 && total != wer.totalSize {
				wer.status.statusCode = StatusCodeInternalError // internal error, to force the download to stop
				wer.err = fmt.Errorf("Content-Range total length is unexpected: %d, need %d", total, wer.totalSize)
				return
			}
		}
	}

	var (
		buf       = cachepool.SyncPool.Get().([]byte)
		n, nn     int
		n64, nn64 int64
	)
	defer cachepool.SyncPool.Put(buf)

	for {
		select {
		case <-workerCancelCtx.Done(): // Cancel
			wer.status.statusCode = StatusCodeCanceled
			return
		case <-resetCtx.Done(): // Reset
			wer.status.statusCode = StatusCodeReseted
			return
		case <-wer.pauseChan: // Pause (Pause records StatusCodePaused itself)
			return
		default:
		}

		wer.status.statusCode = StatusCodeDownloading

		// Fill buf. Stop early on a read error or, in ranged mode, once the
		// range is exhausted.
		var readErr error
		n = 0
		for n < len(buf) && readErr == nil && (single || wer.wrange.Len() > 0) {
			nn, readErr = resp.Body.Read(buf[n:])
			nn64 = int64(nn)
			// Update speed statistics.
			if wer.downloadStatus != nil {
				wer.downloadStatus.AddSpeedsDownloaded(nn64) // rate limiting blocks here
			}
			wer.speedsStat.Add(nn64)
			if wer.globalSpeedsStat != nil {
				wer.globalSpeedsStat.Add(nn64)
			}
			n += nn
		}

		// EOF arriving together with data is reported as io.ErrUnexpectedEOF.
		// The completion check below decides whether that is actually a
		// success (single mode, or range filled).
		if n > 0 && readErr == io.EOF {
			readErr = io.ErrUnexpectedEOF
		}
		n64 = int64(n)

		if !single {
			rangeLength = wer.wrange.Len()
			// The range is already complete.
			if rangeLength <= 0 {
				wer.status.statusCode = StatusCodeCanceled
				wer.err = errors.New("worker already complete")
				return
			}
			if n64 > rangeLength {
				// More data than the range needs: keep only what fits and
				// treat it as the end of the range.
				n64 = rangeLength
				n = int(rangeLength)
				readErr = io.EOF
			}
		}

		// Write the chunk at the current range offset.
		if wer.writerAt != nil {
			wer.status.statusCode = StatusCodeWaitToWrite
			if wer.writeMu != nil {
				wer.writeMu.Lock() // serialise writes to ease disk pressure
			}
			_, wer.err = wer.writerAt.WriteAt(buf[:n], wer.wrange.LoadBegin())
			if wer.writeMu != nil {
				wer.writeMu.Unlock()
			}
			if wer.err != nil {
				wer.status.statusCode = StatusCodeInternalError
				return
			}
			wer.status.statusCode = StatusCodeDownloading
		}

		// Advance the range and update download statistics.
		wer.wrange.AddBegin(n64)
		if wer.downloadStatus != nil {
			wer.downloadStatus.AddDownloaded(n64)
			if single {
				// In single mode the total size is accumulated as data arrives.
				wer.downloadStatus.AddTotalSize(n64)
			}
		}

		if readErr == nil {
			continue
		}

		rlen := wer.wrange.Len()
		switch {
		case (single && readErr == io.ErrUnexpectedEOF) || readErr == io.EOF || rlen <= 0:
			// Download finished. A negative length may occur when the worker
			// has been duplicated.
			wer.status.statusCode = StatusCodeSuccessed
			if rlen < 0 {
				logger.Verbosef("DEBUG: RangeLen is negative at end: %v, %d\n", wer.wrange, wer.wrange.Len())
			}
		default:
			// Any other read error fails the attempt.
			wer.status.statusCode = StatusCodeFailed
			wer.err = readErr
		}
		return
	}
}