// Copyright (c) 2020 tickstep. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package downloader import ( "context" "errors" "fmt" "github.com/tickstep/aliyunpan-api/aliyunpan" "github.com/tickstep/aliyunpan-api/aliyunpan/apierror" "github.com/tickstep/aliyunpan/internal/config" "github.com/tickstep/aliyunpan/library/requester/transfer" "github.com/tickstep/library-go/cachepool" "github.com/tickstep/library-go/logger" "github.com/tickstep/library-go/requester" "github.com/tickstep/library-go/requester/rio/speeds" "io" "net/http" "sync" ) type ( //Worker 工作单元 Worker struct { totalSize int64 // 整个文件的大小, worker请求range时会获取尝试获取该值, 如果不匹配, 则返回错误 wrange *transfer.Range speedsStat *speeds.Speeds globalSpeedsStat *speeds.Speeds // 全局速度统计 id int // work id fileId string // 文件ID driveId string url string // 下载地址 acceptRanges string panClient *config.PanClient client *requester.HTTPClient writerAt io.WriterAt writeMu *sync.Mutex execMu sync.Mutex pauseChan chan struct{} workerCancelFunc context.CancelFunc resetFunc context.CancelFunc readRespBodyCancelFunc func() err error // 错误信息 status WorkerStatus downloadStatus *transfer.DownloadStatus // 总的下载状态 } // WorkerList worker列表 WorkerList []*Worker ) // Duplicate 构造新的列表 func (wl WorkerList) Duplicate() WorkerList { n := make(WorkerList, len(wl)) copy(n, wl) return n } // NewWorker 初始化Worker func NewWorker(id int, driveId string, fileId, durl string, writerAt io.WriterAt, globalSpeedsStat *speeds.Speeds) *Worker { return &Worker{ id: id, url: durl, writerAt: writerAt, fileId: fileId, driveId: driveId, globalSpeedsStat: globalSpeedsStat, } } // ID 返回worker ID func (wer *Worker) ID() int { return wer.id } func (wer *Worker) lazyInit() { if wer.client == nil { wer.client = requester.NewHTTPClient() } if wer.pauseChan == nil { wer.pauseChan = make(chan struct{}) } if wer.wrange == nil { wer.wrange = &transfer.Range{} } if wer.wrange.LoadBegin() == 0 && wer.wrange.LoadEnd() == 0 { // 取消多线程下载 wer.acceptRanges = "" wer.wrange.StoreEnd(-2) } if wer.speedsStat == nil { wer.speedsStat = &speeds.Speeds{} } } // SetTotalSize 设置整个文件的大小, worker请求range时会获取尝试获取该值, 如果不匹配, 则返回错误 func (wer *Worker) SetTotalSize(size int64) { wer.totalSize = size } // SetClient 设置http客户端 func (wer *Worker) SetClient(c *requester.HTTPClient) { wer.client = c } func (wer *Worker) SetPanClient(p *config.PanClient) { wer.panClient = p } // SetAcceptRange 设置AcceptRange func (wer *Worker) SetAcceptRange(acceptRanges string) { wer.acceptRanges = acceptRanges } // SetRange 设置请求范围 func (wer *Worker) SetRange(r *transfer.Range) { if wer.wrange == nil { wer.wrange = r return } wer.wrange.StoreBegin(r.LoadBegin()) wer.wrange.StoreEnd(r.LoadEnd()) } // SetWriteMutex 设置数据写锁 func (wer *Worker) SetWriteMutex(mu *sync.Mutex) { wer.writeMu = mu } // SetDownloadStatus 增加其他需要统计的数据 func (wer *Worker) SetDownloadStatus(downloadStatus *transfer.DownloadStatus) { wer.downloadStatus = downloadStatus } // GetStatus 返回下载状态 func (wer *Worker) GetStatus() WorkerStatuser { // 空接口与空指针不等价 return &wer.status } // GetRange 返回worker范围 func (wer *Worker) GetRange() *transfer.Range { return wer.wrange } // GetSpeedsPerSecond 获取每秒的速度 func (wer *Worker) GetSpeedsPerSecond() int64 { return wer.speedsStat.GetSpeeds() } // Pause 暂停下载 func (wer *Worker) Pause() { wer.lazyInit() if wer.acceptRanges == "" { logger.Verbosef("WARNING: worker unsupport pause") return } if wer.status.statusCode == StatusCodePaused { return } wer.pauseChan <- struct{}{} wer.status.statusCode = StatusCodePaused } // Resume 恢复下载 func (wer *Worker) Resume() { if wer.status.statusCode != StatusCodePaused { return } go wer.Execute() } // Cancel 取消下载 func (wer *Worker) Cancel() error { if wer.workerCancelFunc == nil { return errors.New("cancelFunc not set") } wer.workerCancelFunc() if wer.readRespBodyCancelFunc != nil { wer.readRespBodyCancelFunc() } return nil } // Reset 重设连接 func (wer *Worker) Reset() { if wer.resetFunc == nil { logger.Verbosef("DEBUG: worker: resetFunc not set") return } wer.resetFunc() if wer.readRespBodyCancelFunc != nil { wer.readRespBodyCancelFunc() } wer.ClearStatus() go wer.Execute() } // RefreshDownloadUrl 重新刷新下载链接 func (wer *Worker) RefreshDownloadUrl() { var apierr *apierror.ApiError durl, apierr := wer.panClient.OpenapiPanClient().GetFileDownloadUrl(&aliyunpan.GetFileDownloadUrlParam{DriveId: wer.driveId, FileId: wer.fileId}) if apierr != nil { wer.status.statusCode = StatusCodeTooManyConnections return } wer.url = durl.Url } // Canceled 是否已经取消 func (wer *Worker) Canceled() bool { return wer.status.statusCode == StatusCodeCanceled } // Completed 是否已经完成 func (wer *Worker) Completed() bool { switch wer.status.statusCode { case StatusCodeSuccessed, StatusCodeCanceled: return true default: return false } } // Failed 是否失败 func (wer *Worker) Failed() bool { switch wer.status.statusCode { case StatusCodeFailed, StatusCodeInternalError, StatusCodeTooManyConnections, StatusCodeNetError: return true default: return false } } // ClearStatus 清空状态 func (wer *Worker) ClearStatus() { wer.status.statusCode = StatusCodeInit } // Err 返回worker错误 func (wer *Worker) Err() error { return wer.err } // Execute 执行任务 func (wer *Worker) Execute() { wer.lazyInit() wer.execMu.Lock() defer wer.execMu.Unlock() wer.status.statusCode = StatusCodeInit single := wer.acceptRanges == "" // 如果已暂停, 退出 if wer.status.statusCode == StatusCodePaused { return } if !single { // 已完成 if rlen := wer.wrange.Len(); rlen <= 0 { if rlen < 0 { logger.Verbosef("DEBUG: RangeLen is negative at begin: %v, %d\n", wer.wrange, wer.wrange.Len()) } wer.status.statusCode = StatusCodeSuccessed return } } // zero size file if wer.totalSize == 0 { wer.status.statusCode = StatusCodeSuccessed return } workerCancelCtx, workerCancelFunc := context.WithCancel(context.Background()) wer.workerCancelFunc = workerCancelFunc resetCtx, resetFunc := context.WithCancel(context.Background()) wer.resetFunc = resetFunc wer.status.statusCode = StatusCodePending var resp *http.Response apierr := wer.panClient.OpenapiPanClient().DownloadFileData(wer.url, aliyunpan.FileDownloadRange{ Offset: wer.wrange.Begin, End: wer.wrange.End - 1, }, func(httpMethod, fullUrl string, headers map[string]string) (*http.Response, error) { resp, wer.err = wer.client.Req(httpMethod, fullUrl, nil, headers) if wer.err != nil { return nil, wer.err } return resp, wer.err }) if resp != nil { defer func() { resp.Body.Close() }() wer.readRespBodyCancelFunc = func() { resp.Body.Close() } } if wer.err != nil || apierr != nil { wer.status.statusCode = StatusCodeNetError return } // 判断响应状态 switch resp.StatusCode { case 200, 206: // do nothing, continue wer.status.statusCode = StatusCodeDownloading break case 416: //Requested Range Not Satisfiable fallthrough case 403: // Forbidden fallthrough case 406: // Not Acceptable wer.status.statusCode = StatusCodeNetError wer.err = errors.New(resp.Status) return case 404: wer.status.statusCode = StatusCodeDownloadUrlExpired wer.err = errors.New(resp.Status) return case 429, 509: // Too Many Requests wer.status.SetStatusCode(StatusCodeTooManyConnections) wer.err = errors.New(resp.Status) return default: wer.status.statusCode = StatusCodeNetError wer.err = fmt.Errorf("unexpected http status code, %d, %s", resp.StatusCode, resp.Status) return } var ( contentLength = resp.ContentLength rangeLength = wer.wrange.Len() ) if !single { // 检查请求长度 if contentLength != rangeLength { wer.status.statusCode = StatusCodeNetError wer.err = fmt.Errorf("Content-Length is unexpected: %d, need %d", contentLength, rangeLength) return } // 检查总大小 if wer.totalSize > 0 { total := ParseContentRange(resp.Header.Get("Content-Range")) if total > 0 { if total != wer.totalSize { wer.status.statusCode = StatusCodeInternalError // 这里设置为内部错误, 强制停止下载 wer.err = fmt.Errorf("Content-Range total length is unexpected: %d, need %d", total, wer.totalSize) return } } } } var ( buf = cachepool.SyncPool.Get().([]byte) n, nn int n64, nn64 int64 ) defer cachepool.SyncPool.Put(buf) for { select { case <-workerCancelCtx.Done(): //取消 wer.status.statusCode = StatusCodeCanceled return case <-resetCtx.Done(): //重设连接 wer.status.statusCode = StatusCodeReseted return case <-wer.pauseChan: //暂停 return default: wer.status.statusCode = StatusCodeDownloading // 初始化数据 var readErr error n = 0 // 读取数据 for n < len(buf) && readErr == nil && (single || wer.wrange.Len() > 0) { nn, readErr = resp.Body.Read(buf[n:]) nn64 = int64(nn) // 更新速度统计 if wer.downloadStatus != nil { wer.downloadStatus.AddSpeedsDownloaded(nn64) // 限速在这里阻塞 } wer.speedsStat.Add(nn64) if wer.globalSpeedsStat != nil { wer.globalSpeedsStat.Add(nn64) } n += nn } if n > 0 && readErr == io.EOF { readErr = io.ErrUnexpectedEOF } n64 = int64(n) // 非单线程模式下 if !single { rangeLength = wer.wrange.Len() // 已完成 if rangeLength <= 0 { wer.status.statusCode = StatusCodeCanceled wer.err = errors.New("worker already complete") return } if n64 > rangeLength { // 数据大小不正常 n64 = rangeLength n = int(rangeLength) readErr = io.EOF } } // 写入数据 if wer.writerAt != nil { wer.status.statusCode = StatusCodeWaitToWrite if wer.writeMu != nil { wer.writeMu.Lock() // 加锁, 减轻硬盘的压力 } _, wer.err = wer.writerAt.WriteAt(buf[:n], wer.wrange.Begin) // 写入数据 if wer.err != nil { if wer.writeMu != nil { wer.writeMu.Unlock() //解锁 } wer.status.statusCode = StatusCodeInternalError return } if wer.writeMu != nil { wer.writeMu.Unlock() //解锁 } wer.status.statusCode = StatusCodeDownloading } // 更新下载统计数据 wer.wrange.AddBegin(n64) if wer.downloadStatus != nil { wer.downloadStatus.AddDownloaded(n64) if single { wer.downloadStatus.AddTotalSize(n64) } } if readErr != nil { rlen := wer.wrange.Len() switch { case single && readErr == io.ErrUnexpectedEOF: // 单线程判断下载成功 fallthrough case readErr == io.EOF: fallthrough case rlen <= 0: // 下载完成 // 小于0可能是因为 worker 被 duplicate wer.status.statusCode = StatusCodeSuccessed if rlen < 0 { logger.Verbosef("DEBUG: RangeLen is negative at end: %v, %d\n", wer.wrange, wer.wrange.Len()) } return default: // 其他错误, 返回 wer.status.statusCode = StatusCodeFailed wer.err = readErr return } } } } }