bump deps

This commit is contained in:
Jeff Mitchell
2016-12-13 19:12:26 -05:00
parent 9eb3368015
commit 35a1917bc7
121 changed files with 18139 additions and 2022 deletions

View File

@@ -23,7 +23,7 @@ All of the methods of this package use exponential backoff to retry calls
that fail with certain errors, as described in
https://cloud.google.com/storage/docs/exponential-backoff.
Note: This package is experimental and may make backwards-incompatible changes.
Note: This package is in beta. Some backwards-incompatible changes may occur.
Creating a Client
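For readers unfamiliar with the retry behavior described in that note, the sketch below shows one common way exponential backoff with jitter is implemented in Go. It is illustrative only: the attempt limit, base delay, and the retryable predicate are assumptions for this sketch and are not taken from the cloud.google.com/go/storage package.

package retryexample

import (
	"errors"
	"math/rand"
	"time"
)

// doWithBackoff retries op with exponentially growing, jittered delays,
// as long as retryable reports the returned error as transient.
// The 5-attempt cap and 100ms base delay are arbitrary for this sketch.
func doWithBackoff(op func() error, retryable func(error) bool) error {
	delay := 100 * time.Millisecond
	var err error
	for attempt := 0; attempt < 5; attempt++ {
		if err = op(); err == nil || !retryable(err) {
			return err
		}
		// Sleep for the current delay plus random jitter, then double it.
		time.Sleep(delay + time.Duration(rand.Int63n(int64(delay))))
		delay *= 2
	}
	return errors.New("retryexample: retries exhausted: " + err.Error())
}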

View File

@@ -123,6 +123,7 @@ type BlobProperties struct {
CopyCompletionTime string `xml:"CopyCompletionTime"`
CopyStatusDescription string `xml:"CopyStatusDescription"`
LeaseStatus string `xml:"LeaseStatus"`
LeaseState string `xml:"LeaseState"`
}
// BlobHeaders contains various properties of a blob and is an entry
@@ -434,7 +435,7 @@ func (b BlobStorageClient) ListContainers(params ListContainersParameters) (Cont
headers := b.client.getStandardHeaders()
var out ContainerListResponse
resp, err := b.client.exec("GET", uri, headers, nil)
resp, err := b.client.exec(http.MethodGet, uri, headers, nil)
if err != nil {
return out, err
}
@@ -471,24 +472,22 @@ func (b BlobStorageClient) CreateContainerIfNotExists(name string, access Contai
}
func (b BlobStorageClient) createContainer(name string, access ContainerAccessType) (*storageResponse, error) {
verb := "PUT"
uri := b.client.getEndpoint(blobServiceName, pathForContainer(name), url.Values{"restype": {"container"}})
headers := b.client.getStandardHeaders()
if access != "" {
headers[ContainerAccessHeader] = string(access)
}
return b.client.exec(verb, uri, headers, nil)
return b.client.exec(http.MethodPut, uri, headers, nil)
}
// ContainerExists returns true if a container with given name exists
// on the storage account, otherwise returns false.
func (b BlobStorageClient) ContainerExists(name string) (bool, error) {
verb := "HEAD"
uri := b.client.getEndpoint(blobServiceName, pathForContainer(name), url.Values{"restype": {"container"}})
headers := b.client.getStandardHeaders()
resp, err := b.client.exec(verb, uri, headers, nil)
resp, err := b.client.exec(http.MethodHead, uri, headers, nil)
if resp != nil {
defer resp.body.Close()
if resp.statusCode == http.StatusOK || resp.statusCode == http.StatusNotFound {
@@ -528,9 +527,9 @@ func (b BlobStorageClient) SetContainerPermissions(container string, containerPe
var resp *storageResponse
if accessPolicyXML != "" {
headers["Content-Length"] = strconv.Itoa(len(accessPolicyXML))
resp, err = b.client.exec("PUT", uri, headers, strings.NewReader(accessPolicyXML))
resp, err = b.client.exec(http.MethodPut, uri, headers, strings.NewReader(accessPolicyXML))
} else {
resp, err = b.client.exec("PUT", uri, headers, nil)
resp, err = b.client.exec(http.MethodPut, uri, headers, nil)
}
if err != nil {
@@ -538,9 +537,7 @@ func (b BlobStorageClient) SetContainerPermissions(container string, containerPe
}
if resp != nil {
defer func() {
err = resp.body.Close()
}()
defer resp.body.Close()
if resp.statusCode != http.StatusOK {
return errors.New("Unable to set permissions")
@@ -568,7 +565,7 @@ func (b BlobStorageClient) GetContainerPermissions(container string, timeout int
headers[leaseID] = leaseID
}
resp, err := b.client.exec("GET", uri, headers, nil)
resp, err := b.client.exec(http.MethodGet, uri, headers, nil)
if err != nil {
return nil, err
}
@@ -576,9 +573,7 @@ func (b BlobStorageClient) GetContainerPermissions(container string, timeout int
// containerAccess. Blob, Container, empty
containerAccess := resp.headers.Get(http.CanonicalHeaderKey(ContainerAccessHeader))
defer func() {
err = resp.body.Close()
}()
defer resp.body.Close()
var out AccessPolicy
err = xmlUnmarshal(resp.body, &out.SignedIdentifiersList)
@@ -624,11 +619,10 @@ func (b BlobStorageClient) DeleteContainerIfExists(name string) (bool, error) {
}
func (b BlobStorageClient) deleteContainer(name string) (*storageResponse, error) {
verb := "DELETE"
uri := b.client.getEndpoint(blobServiceName, pathForContainer(name), url.Values{"restype": {"container"}})
headers := b.client.getStandardHeaders()
return b.client.exec(verb, uri, headers, nil)
return b.client.exec(http.MethodDelete, uri, headers, nil)
}
// ListBlobs returns an object that contains list of blobs in the container,
@@ -643,7 +637,7 @@ func (b BlobStorageClient) ListBlobs(container string, params ListBlobsParameter
headers := b.client.getStandardHeaders()
var out BlobListResponse
resp, err := b.client.exec("GET", uri, headers, nil)
resp, err := b.client.exec(http.MethodGet, uri, headers, nil)
if err != nil {
return out, err
}
@@ -656,10 +650,9 @@ func (b BlobStorageClient) ListBlobs(container string, params ListBlobsParameter
// BlobExists returns true if a blob with given name exists on the specified
// container of the storage account.
func (b BlobStorageClient) BlobExists(container, name string) (bool, error) {
verb := "HEAD"
uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), url.Values{})
headers := b.client.getStandardHeaders()
resp, err := b.client.exec(verb, uri, headers, nil)
resp, err := b.client.exec(http.MethodHead, uri, headers, nil)
if resp != nil {
defer resp.body.Close()
if resp.statusCode == http.StatusOK || resp.statusCode == http.StatusNotFound {
@@ -713,7 +706,6 @@ func (b BlobStorageClient) GetBlobRange(container, name, bytesRange string, extr
}
func (b BlobStorageClient) getBlobRange(container, name, bytesRange string, extraHeaders map[string]string) (*storageResponse, error) {
verb := "GET"
uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), url.Values{})
headers := b.client.getStandardHeaders()
@@ -725,7 +717,7 @@ func (b BlobStorageClient) getBlobRange(container, name, bytesRange string, extr
headers[k] = v
}
resp, err := b.client.exec(verb, uri, headers, nil)
resp, err := b.client.exec(http.MethodGet, uri, headers, nil)
if err != nil {
return nil, err
}
@@ -737,7 +729,7 @@ func (b BlobStorageClient) leaseCommonPut(container string, name string, headers
params := url.Values{"comp": {"lease"}}
uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), params)
resp, err := b.client.exec("PUT", uri, headers, nil)
resp, err := b.client.exec(http.MethodPut, uri, headers, nil)
if err != nil {
return nil, err
}
@@ -764,7 +756,7 @@ func (b BlobStorageClient) SnapshotBlob(container string, name string, timeout i
}
uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), params)
resp, err := b.client.exec("PUT", uri, headers, nil)
resp, err := b.client.exec(http.MethodPut, uri, headers, nil)
if err != nil {
return nil, err
}
@@ -903,17 +895,16 @@ func (b BlobStorageClient) RenewLease(container string, name string, currentLeas
// GetBlobProperties provides various information about the specified
// blob. See https://msdn.microsoft.com/en-us/library/azure/dd179394.aspx
func (b BlobStorageClient) GetBlobProperties(container, name string) (*BlobProperties, error) {
verb := "HEAD"
uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), url.Values{})
headers := b.client.getStandardHeaders()
resp, err := b.client.exec(verb, uri, headers, nil)
resp, err := b.client.exec(http.MethodHead, uri, headers, nil)
if err != nil {
return nil, err
}
defer resp.body.Close()
if err := checkRespCode(resp.statusCode, []int{http.StatusOK}); err != nil {
if err = checkRespCode(resp.statusCode, []int{http.StatusOK}); err != nil {
return nil, err
}
@@ -953,6 +944,7 @@ func (b BlobStorageClient) GetBlobProperties(container, name string) (*BlobPrope
CopyStatus: resp.headers.Get("x-ms-copy-status"),
BlobType: BlobType(resp.headers.Get("x-ms-blob-type")),
LeaseStatus: resp.headers.Get("x-ms-lease-status"),
LeaseState: resp.headers.Get("x-ms-lease-state"),
}, nil
}
@@ -975,7 +967,7 @@ func (b BlobStorageClient) SetBlobProperties(container, name string, blobHeaders
headers[k] = v
}
resp, err := b.client.exec("PUT", uri, headers, nil)
resp, err := b.client.exec(http.MethodPut, uri, headers, nil)
if err != nil {
return err
}
@@ -1004,7 +996,7 @@ func (b BlobStorageClient) SetBlobMetadata(container, name string, metadata map[
headers[k] = v
}
resp, err := b.client.exec("PUT", uri, headers, nil)
resp, err := b.client.exec(http.MethodPut, uri, headers, nil)
if err != nil {
return err
}
@@ -1024,7 +1016,7 @@ func (b BlobStorageClient) GetBlobMetadata(container, name string) (map[string]s
uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), params)
headers := b.client.getStandardHeaders()
resp, err := b.client.exec("GET", uri, headers, nil)
resp, err := b.client.exec(http.MethodGet, uri, headers, nil)
if err != nil {
return nil, err
}
@@ -1083,7 +1075,7 @@ func (b BlobStorageClient) CreateBlockBlobFromReader(container, name string, siz
headers[k] = v
}
resp, err := b.client.exec("PUT", uri, headers, blob)
resp, err := b.client.exec(http.MethodPut, uri, headers, blob)
if err != nil {
return err
}
@@ -1120,7 +1112,7 @@ func (b BlobStorageClient) PutBlockWithLength(container, name, blockID string, s
headers[k] = v
}
resp, err := b.client.exec("PUT", uri, headers, blob)
resp, err := b.client.exec(http.MethodPut, uri, headers, blob)
if err != nil {
return err
}
@@ -1139,7 +1131,7 @@ func (b BlobStorageClient) PutBlockList(container, name string, blocks []Block)
headers := b.client.getStandardHeaders()
headers["Content-Length"] = fmt.Sprintf("%v", len(blockListXML))
resp, err := b.client.exec("PUT", uri, headers, strings.NewReader(blockListXML))
resp, err := b.client.exec(http.MethodPut, uri, headers, strings.NewReader(blockListXML))
if err != nil {
return err
}
@@ -1156,7 +1148,7 @@ func (b BlobStorageClient) GetBlockList(container, name string, blockType BlockL
headers := b.client.getStandardHeaders()
var out BlockListResponse
resp, err := b.client.exec("GET", uri, headers, nil)
resp, err := b.client.exec(http.MethodGet, uri, headers, nil)
if err != nil {
return out, err
}
@@ -1182,7 +1174,7 @@ func (b BlobStorageClient) PutPageBlob(container, name string, size int64, extra
headers[k] = v
}
resp, err := b.client.exec("PUT", uri, headers, nil)
resp, err := b.client.exec(http.MethodPut, uri, headers, nil)
if err != nil {
return err
}
@@ -1217,7 +1209,7 @@ func (b BlobStorageClient) PutPage(container, name string, startByte, endByte in
}
headers["Content-Length"] = fmt.Sprintf("%v", contentLength)
resp, err := b.client.exec("PUT", uri, headers, data)
resp, err := b.client.exec(http.MethodPut, uri, headers, data)
if err != nil {
return err
}
@@ -1235,13 +1227,13 @@ func (b BlobStorageClient) GetPageRanges(container, name string) (GetPageRangesR
headers := b.client.getStandardHeaders()
var out GetPageRangesResponse
resp, err := b.client.exec("GET", uri, headers, nil)
resp, err := b.client.exec(http.MethodGet, uri, headers, nil)
if err != nil {
return out, err
}
defer resp.body.Close()
if err := checkRespCode(resp.statusCode, []int{http.StatusOK}); err != nil {
if err = checkRespCode(resp.statusCode, []int{http.StatusOK}); err != nil {
return out, err
}
err = xmlUnmarshal(resp.body, &out)
@@ -1262,7 +1254,7 @@ func (b BlobStorageClient) PutAppendBlob(container, name string, extraHeaders ma
headers[k] = v
}
resp, err := b.client.exec("PUT", uri, headers, nil)
resp, err := b.client.exec(http.MethodPut, uri, headers, nil)
if err != nil {
return err
}
@@ -1285,7 +1277,7 @@ func (b BlobStorageClient) AppendBlock(container, name string, chunk []byte, ext
headers[k] = v
}
resp, err := b.client.exec("PUT", uri, headers, bytes.NewReader(chunk))
resp, err := b.client.exec(http.MethodPut, uri, headers, bytes.NewReader(chunk))
if err != nil {
return err
}
@@ -1320,7 +1312,7 @@ func (b BlobStorageClient) StartBlobCopy(container, name, sourceBlob string) (st
headers := b.client.getStandardHeaders()
headers["x-ms-copy-source"] = sourceBlob
resp, err := b.client.exec("PUT", uri, headers, nil)
resp, err := b.client.exec(http.MethodPut, uri, headers, nil)
if err != nil {
return "", err
}
@@ -1355,7 +1347,7 @@ func (b BlobStorageClient) AbortBlobCopy(container, name, copyID, currentLeaseID
headers[leaseID] = currentLeaseID
}
resp, err := b.client.exec("PUT", uri, headers, nil)
resp, err := b.client.exec(http.MethodPut, uri, headers, nil)
if err != nil {
return err
}
@@ -1423,14 +1415,13 @@ func (b BlobStorageClient) DeleteBlobIfExists(container, name string, extraHeade
}
func (b BlobStorageClient) deleteBlob(container, name string, extraHeaders map[string]string) (*storageResponse, error) {
verb := "DELETE"
uri := b.client.getEndpoint(blobServiceName, pathForBlob(container, name), url.Values{})
headers := b.client.getStandardHeaders()
for k, v := range extraHeaders {
headers[k] = v
}
return b.client.exec(verb, uri, headers, nil)
return b.client.exec(http.MethodDelete, uri, headers, nil)
}
// helper method to construct the path to a container given its name

View File

@@ -387,11 +387,8 @@ func (c Client) exec(verb, url string, headers map[string]string, body io.Reader
if err != nil {
return nil, err
}
headers["Authorization"] = authHeader
if err != nil {
return nil, err
}
headers["Authorization"] = authHeader
req, err := http.NewRequest(verb, url, body)
if err != nil {
return nil, errors.New("azure/storage: error creating request: " + err.Error())
@@ -428,12 +425,13 @@ func (c Client) exec(verb, url string, headers map[string]string, body io.Reader
return nil, err
}
requestID := resp.Header.Get("x-ms-request-id")
if len(respBody) == 0 {
// no error in response body
err = fmt.Errorf("storage: service returned without a response body (%s)", resp.Status)
// no error in response body, might happen in HEAD requests
err = serviceErrFromStatusCode(resp.StatusCode, resp.Status, requestID)
} else {
// response contains storage service error object, unmarshal
storageErr, errIn := serviceErrFromXML(respBody, resp.StatusCode, resp.Header.Get("x-ms-request-id"))
storageErr, errIn := serviceErrFromXML(respBody, resp.StatusCode, requestID)
if err != nil { // error unmarshaling the error response
err = errIn
}
@@ -482,8 +480,8 @@ func (c Client) execInternalJSON(verb, url string, headers map[string]string, bo
}
if len(respBody) == 0 {
// no error in response body
err = fmt.Errorf("storage: service returned without a response body (%d)", resp.StatusCode)
// no error in response body, might happen in HEAD requests
err = serviceErrFromStatusCode(resp.StatusCode, resp.Status, resp.Header.Get("x-ms-request-id"))
return respToRet, err
}
// try unmarshal as odata.error json
@@ -535,6 +533,15 @@ func serviceErrFromXML(body []byte, statusCode int, requestID string) (AzureStor
return storageErr, nil
}
func serviceErrFromStatusCode(code int, status string, requestID string) AzureStorageServiceError {
return AzureStorageServiceError{
StatusCode: code,
Code: status,
RequestID: requestID,
Message: "no response body was available for error status code",
}
}
func (e AzureStorageServiceError) Error() string {
return fmt.Sprintf("storage: service returned error: StatusCode=%d, ErrorCode=%s, ErrorMessage=%s, RequestId=%s, QueryParameterName=%s, QueryParameterValue=%s",
e.StatusCode, e.Code, e.Message, e.RequestID, e.QueryParameterName, e.QueryParameterValue)
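As a hypothetical caller-side illustration (not part of this diff), the new serviceErrFromStatusCode path means even body-less responses, such as failed HEAD requests, surface as the exported AzureStorageServiceError shown above, so callers can inspect its fields. The value-type assertion below is an assumption about how the client returns the error.

package example

import (
	"log"

	"github.com/Azure/azure-sdk-for-go/storage"
)

// logStorageError prints the structured fields carried by an Azure storage
// service error, falling back to a generic message for other error types.
func logStorageError(err error) {
	if serr, ok := err.(storage.AzureStorageServiceError); ok {
		log.Printf("storage error: status=%d code=%s requestID=%s message=%s",
			serr.StatusCode, serr.Code, serr.RequestID, serr.Message)
		return
	}
	log.Printf("error: %v", err)
}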

View File

@@ -250,7 +250,6 @@ func (f FileServiceClient) ListDirsAndFiles(path string, params ListDirsAndFiles
if err != nil {
return out, err
}
defer resp.body.Close()
err = xmlUnmarshal(resp.body, &out)
return out, err
@@ -302,7 +301,6 @@ func (f FileServiceClient) ListShares(params ListSharesParameters) (ShareListRes
if err != nil {
return out, err
}
defer resp.body.Close()
err = xmlUnmarshal(resp.body, &out)
return out, err
@@ -400,7 +398,6 @@ func (f FileServiceClient) modifyRange(path string, bytes io.Reader, fileRange F
if err != nil {
return err
}
defer resp.body.Close()
return checkRespCode(resp.statusCode, []int{http.StatusCreated})
}

View File

@@ -156,7 +156,7 @@ func (c QueueServiceClient) SetMetadata(name string, metadata map[string]string)
headers[userDefinedMetadataHeaderPrefix+k] = v
}
resp, err := c.client.exec("PUT", uri, headers, nil)
resp, err := c.client.exec(http.MethodPut, uri, headers, nil)
if err != nil {
return err
}
@@ -179,7 +179,7 @@ func (c QueueServiceClient) GetMetadata(name string) (QueueMetadataResponse, err
qm.UserDefinedMetadata = make(map[string]string)
uri := c.client.getEndpoint(queueServiceName, pathForQueue(name), url.Values{"comp": []string{"metadata"}})
headers := c.client.getStandardHeaders()
resp, err := c.client.exec("GET", uri, headers, nil)
resp, err := c.client.exec(http.MethodGet, uri, headers, nil)
if err != nil {
return qm, err
}
@@ -212,7 +212,7 @@ func (c QueueServiceClient) GetMetadata(name string) (QueueMetadataResponse, err
func (c QueueServiceClient) CreateQueue(name string) error {
uri := c.client.getEndpoint(queueServiceName, pathForQueue(name), url.Values{})
headers := c.client.getStandardHeaders()
resp, err := c.client.exec("PUT", uri, headers, nil)
resp, err := c.client.exec(http.MethodPut, uri, headers, nil)
if err != nil {
return err
}
@@ -225,7 +225,7 @@ func (c QueueServiceClient) CreateQueue(name string) error {
// See https://msdn.microsoft.com/en-us/library/azure/dd179436.aspx
func (c QueueServiceClient) DeleteQueue(name string) error {
uri := c.client.getEndpoint(queueServiceName, pathForQueue(name), url.Values{})
resp, err := c.client.exec("DELETE", uri, c.client.getStandardHeaders(), nil)
resp, err := c.client.exec(http.MethodDelete, uri, c.client.getStandardHeaders(), nil)
if err != nil {
return err
}
@@ -236,7 +236,7 @@ func (c QueueServiceClient) DeleteQueue(name string) error {
// QueueExists returns true if a queue with given name exists.
func (c QueueServiceClient) QueueExists(name string) (bool, error) {
uri := c.client.getEndpoint(queueServiceName, pathForQueue(name), url.Values{"comp": {"metadata"}})
resp, err := c.client.exec("GET", uri, c.client.getStandardHeaders(), nil)
resp, err := c.client.exec(http.MethodGet, uri, c.client.getStandardHeaders(), nil)
if resp != nil && (resp.statusCode == http.StatusOK || resp.statusCode == http.StatusNotFound) {
return resp.statusCode == http.StatusOK, nil
}
@@ -256,7 +256,7 @@ func (c QueueServiceClient) PutMessage(queue string, message string, params PutM
}
headers := c.client.getStandardHeaders()
headers["Content-Length"] = strconv.Itoa(nn)
resp, err := c.client.exec("POST", uri, headers, body)
resp, err := c.client.exec(http.MethodPost, uri, headers, body)
if err != nil {
return err
}
@@ -269,7 +269,7 @@ func (c QueueServiceClient) PutMessage(queue string, message string, params PutM
// See https://msdn.microsoft.com/en-us/library/azure/dd179454.aspx
func (c QueueServiceClient) ClearMessages(queue string) error {
uri := c.client.getEndpoint(queueServiceName, pathForQueueMessages(queue), url.Values{})
resp, err := c.client.exec("DELETE", uri, c.client.getStandardHeaders(), nil)
resp, err := c.client.exec(http.MethodDelete, uri, c.client.getStandardHeaders(), nil)
if err != nil {
return err
}
@@ -284,7 +284,7 @@ func (c QueueServiceClient) ClearMessages(queue string) error {
func (c QueueServiceClient) GetMessages(queue string, params GetMessagesParameters) (GetMessagesResponse, error) {
var r GetMessagesResponse
uri := c.client.getEndpoint(queueServiceName, pathForQueueMessages(queue), params.getParameters())
resp, err := c.client.exec("GET", uri, c.client.getStandardHeaders(), nil)
resp, err := c.client.exec(http.MethodGet, uri, c.client.getStandardHeaders(), nil)
if err != nil {
return r, err
}
@@ -300,7 +300,7 @@ func (c QueueServiceClient) GetMessages(queue string, params GetMessagesParamete
func (c QueueServiceClient) PeekMessages(queue string, params PeekMessagesParameters) (PeekMessagesResponse, error) {
var r PeekMessagesResponse
uri := c.client.getEndpoint(queueServiceName, pathForQueueMessages(queue), params.getParameters())
resp, err := c.client.exec("GET", uri, c.client.getStandardHeaders(), nil)
resp, err := c.client.exec(http.MethodGet, uri, c.client.getStandardHeaders(), nil)
if err != nil {
return r, err
}
@@ -315,7 +315,7 @@ func (c QueueServiceClient) PeekMessages(queue string, params PeekMessagesParame
func (c QueueServiceClient) DeleteMessage(queue, messageID, popReceipt string) error {
uri := c.client.getEndpoint(queueServiceName, pathForMessage(queue, messageID), url.Values{
"popreceipt": {popReceipt}})
resp, err := c.client.exec("DELETE", uri, c.client.getStandardHeaders(), nil)
resp, err := c.client.exec(http.MethodDelete, uri, c.client.getStandardHeaders(), nil)
if err != nil {
return err
}
@@ -335,7 +335,7 @@ func (c QueueServiceClient) UpdateMessage(queue string, messageID string, messag
}
headers := c.client.getStandardHeaders()
headers["Content-Length"] = fmt.Sprintf("%d", nn)
resp, err := c.client.exec("PUT", uri, headers, body)
resp, err := c.client.exec(http.MethodPut, uri, headers, body)
if err != nil {
return err
}

View File

@@ -45,7 +45,7 @@ func (c *TableServiceClient) QueryTables() ([]AzureTable, error) {
headers := c.getStandardHeaders()
headers["Content-Length"] = "0"
resp, err := c.client.execTable("GET", uri, headers, nil)
resp, err := c.client.execTable(http.MethodGet, uri, headers, nil)
if err != nil {
return nil, err
}
@@ -56,7 +56,9 @@ func (c *TableServiceClient) QueryTables() ([]AzureTable, error) {
}
buf := new(bytes.Buffer)
buf.ReadFrom(resp.body)
if _, err := buf.ReadFrom(resp.body); err != nil {
return nil, err
}
var respArray queryTablesResponse
if err := json.Unmarshal(buf.Bytes(), &respArray); err != nil {
@@ -88,7 +90,7 @@ func (c *TableServiceClient) CreateTable(table AzureTable) error {
headers["Content-Length"] = fmt.Sprintf("%d", buf.Len())
resp, err := c.client.execTable("POST", uri, headers, buf)
resp, err := c.client.execTable(http.MethodPost, uri, headers, buf)
if err != nil {
return err
@@ -114,7 +116,7 @@ func (c *TableServiceClient) DeleteTable(table AzureTable) error {
headers["Content-Length"] = "0"
resp, err := c.client.execTable("DELETE", uri, headers, nil)
resp, err := c.client.execTable(http.MethodDelete, uri, headers, nil)
if err != nil {
return err

View File

@@ -98,7 +98,7 @@ func (c *TableServiceClient) QueryTableEntities(tableName AzureTable, previousCo
headers["Content-Length"] = "0"
resp, err := c.client.execTable("GET", uri, headers, nil)
resp, err := c.client.execTable(http.MethodGet, uri, headers, nil)
if err != nil {
return nil, nil, err
@@ -106,12 +106,9 @@ func (c *TableServiceClient) QueryTableEntities(tableName AzureTable, previousCo
contToken := extractContinuationTokenFromHeaders(resp.headers)
if err != nil {
return nil, contToken, err
}
defer resp.body.Close()
if err := checkRespCode(resp.statusCode, []int{http.StatusOK}); err != nil {
if err = checkRespCode(resp.statusCode, []int{http.StatusOK}); err != nil {
return nil, contToken, err
}
@@ -127,13 +124,11 @@ func (c *TableServiceClient) QueryTableEntities(tableName AzureTable, previousCo
// The function fails if there is an entity with the same
// PartitionKey and RowKey in the table.
func (c *TableServiceClient) InsertEntity(table AzureTable, entity TableEntity) error {
var err error
if sc, err := c.execTable(table, entity, false, "POST"); err != nil {
if sc, err := c.execTable(table, entity, false, http.MethodPost); err != nil {
return checkRespCode(sc, []int{http.StatusCreated})
}
return err
return nil
}
func (c *TableServiceClient) execTable(table AzureTable, entity TableEntity, specifyKeysInURL bool, method string) (int, error) {
@@ -152,10 +147,7 @@ func (c *TableServiceClient) execTable(table AzureTable, entity TableEntity, spe
headers["Content-Length"] = fmt.Sprintf("%d", buf.Len())
var err error
var resp *odataResponse
resp, err = c.client.execTable(method, uri, headers, &buf)
resp, err := c.client.execTable(method, uri, headers, &buf)
if err != nil {
return 0, err
@@ -170,12 +162,10 @@ func (c *TableServiceClient) execTable(table AzureTable, entity TableEntity, spe
// one passed as parameter. The function fails if there is no entity
// with the same PartitionKey and RowKey in the table.
func (c *TableServiceClient) UpdateEntity(table AzureTable, entity TableEntity) error {
var err error
if sc, err := c.execTable(table, entity, true, "PUT"); err != nil {
if sc, err := c.execTable(table, entity, true, http.MethodPut); err != nil {
return checkRespCode(sc, []int{http.StatusNoContent})
}
return err
return nil
}
// MergeEntity merges the contents of an entity with the
@@ -183,12 +173,10 @@ func (c *TableServiceClient) UpdateEntity(table AzureTable, entity TableEntity)
// The function fails if there is no entity
// with the same PartitionKey and RowKey in the table.
func (c *TableServiceClient) MergeEntity(table AzureTable, entity TableEntity) error {
var err error
if sc, err := c.execTable(table, entity, true, "MERGE"); err != nil {
return checkRespCode(sc, []int{http.StatusNoContent})
}
return err
return nil
}
// DeleteEntityWithoutCheck deletes the entity matching by
@@ -214,7 +202,7 @@ func (c *TableServiceClient) DeleteEntity(table AzureTable, entity TableEntity,
headers["Content-Length"] = "0"
headers["If-Match"] = ifMatch
resp, err := c.client.execTable("DELETE", uri, headers, nil)
resp, err := c.client.execTable(http.MethodDelete, uri, headers, nil)
if err != nil {
return err
@@ -231,23 +219,19 @@ func (c *TableServiceClient) DeleteEntity(table AzureTable, entity TableEntity,
// InsertOrReplaceEntity inserts an entity in the specified table
// or replaces the existing one.
func (c *TableServiceClient) InsertOrReplaceEntity(table AzureTable, entity TableEntity) error {
var err error
if sc, err := c.execTable(table, entity, true, "PUT"); err != nil {
if sc, err := c.execTable(table, entity, true, http.MethodPut); err != nil {
return checkRespCode(sc, []int{http.StatusNoContent})
}
return err
return nil
}
// InsertOrMergeEntity inserts an entity in the specified table
// or merges the existing one.
func (c *TableServiceClient) InsertOrMergeEntity(table AzureTable, entity TableEntity) error {
var err error
if sc, err := c.execTable(table, entity, true, "MERGE"); err != nil {
return checkRespCode(sc, []int{http.StatusNoContent})
}
return err
return nil
}
func injectPartitionAndRowKeys(entity TableEntity, buf *bytes.Buffer) error {
@@ -340,8 +324,12 @@ func deserializeEntity(retType reflect.Type, reader io.Reader) ([]TableEntity, e
}
// Reset PartitionKey and RowKey
tEntries[i].SetPartitionKey(pKey)
tEntries[i].SetRowKey(rKey)
if err := tEntries[i].SetPartitionKey(pKey); err != nil {
return nil, err
}
if err := tEntries[i].SetRowKey(rKey); err != nil {
return nil, err
}
}
return tEntries, nil

View File

@@ -11,9 +11,11 @@ import (
// A Config provides configuration to a service client instance.
type Config struct {
Config *aws.Config
Handlers request.Handlers
Endpoint, SigningRegion string
Config *aws.Config
Handlers request.Handlers
Endpoint string
SigningRegion string
SigningName string
}
// ConfigProvider provides a generic way for a service client to receive

View File

@@ -5,6 +5,7 @@ import (
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
)
// UseServiceDefaultRetries instructs the config to use the service's own
@@ -48,6 +49,10 @@ type Config struct {
// endpoint for a client.
Endpoint *string
// The resolver to use for looking up endpoints for AWS service clients
// to use based on region.
EndpointResolver endpoints.Resolver
// The region to send requests to. This parameter is required and must
// be configured globally or on a per-client basis unless otherwise
// noted. A full list of regions is found in the "Regions and Endpoints"
@@ -235,6 +240,13 @@ func (c *Config) WithEndpoint(endpoint string) *Config {
return c
}
// WithEndpointResolver sets a config EndpointResolver value returning a
// Config pointer for chaining.
func (c *Config) WithEndpointResolver(resolver endpoints.Resolver) *Config {
c.EndpointResolver = resolver
return c
}
// WithRegion sets a config Region value returning a Config pointer for
// chaining.
func (c *Config) WithRegion(region string) *Config {
@@ -357,6 +369,10 @@ func mergeInConfig(dst *Config, other *Config) {
dst.Endpoint = other.Endpoint
}
if other.EndpointResolver != nil {
dst.EndpointResolver = other.EndpointResolver
}
if other.Region != nil {
dst.Region = other.Region
}
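A brief usage sketch of the new EndpointResolver field and WithEndpointResolver setter added above (the region string is an arbitrary example): the setter chains like the existing Config setters, and mergeInConfig carries a non-nil resolver across merged configs.

package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/endpoints"
)

func main() {
	// WithEndpointResolver chains like any other Config setter.
	cfg := aws.NewConfig().
		WithRegion("us-west-2").
		WithEndpointResolver(endpoints.DefaultResolver())
	fmt.Println(aws.StringValue(cfg.Region), cfg.EndpointResolver != nil)
}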

View File

@@ -19,8 +19,8 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/endpoints"
)
// A Defaults provides a collection of default values for SDK clients.
@@ -56,7 +56,8 @@ func Config() *aws.Config {
WithMaxRetries(aws.UseServiceDefaultRetries).
WithLogger(aws.NewDefaultLogger()).
WithLogLevel(aws.LogOff).
WithSleepDelay(time.Sleep)
WithSleepDelay(time.Sleep).
WithEndpointResolver(endpoints.DefaultResolver())
}
// Handlers returns the default request handlers.
@@ -120,11 +121,14 @@ func ecsCredProvider(cfg aws.Config, handlers request.Handlers, uri string) cred
}
func ec2RoleProvider(cfg aws.Config, handlers request.Handlers) credentials.Provider {
endpoint, signingRegion := endpoints.EndpointForRegion(ec2metadata.ServiceName,
aws.StringValue(cfg.Region), true, false)
resolver := cfg.EndpointResolver
if resolver == nil {
resolver = endpoints.DefaultResolver()
}
e, _ := resolver.EndpointFor(endpoints.Ec2metadataServiceID, "")
return &ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.NewClient(cfg, handlers, endpoint, signingRegion),
Client: ec2metadata.NewClient(cfg, handlers, e.URL, e.SigningRegion),
ExpiryWindow: 5 * time.Minute,
}
}

View File

@@ -0,0 +1,133 @@
package endpoints
import (
"encoding/json"
"fmt"
"io"
"github.com/aws/aws-sdk-go/aws/awserr"
)
type modelDefinition map[string]json.RawMessage
// A DecodeModelOptions are the options for how the endpoints model definition
// are decoded.
type DecodeModelOptions struct {
SkipCustomizations bool
}
// Set combines all of the option functions together.
func (d *DecodeModelOptions) Set(optFns ...func(*DecodeModelOptions)) {
for _, fn := range optFns {
fn(d)
}
}
// DecodeModel unmarshals a Regions and Endpoint model definition file into
// an endpoint Resolver. If the file format is not supported, or an error occurs
// when unmarshaling the model, an error will be returned.
//
// Casting the return value of this func to an EnumPartitions will
// allow you to get a list of the partitions in the order the endpoints
// will be resolved in.
//
// resolver, err := endpoints.DecodeModel(reader)
//
// partitions := resolver.(endpoints.EnumPartitions).Partitions()
// for _, p := range partitions {
// // ... inspect partitions
// }
func DecodeModel(r io.Reader, optFns ...func(*DecodeModelOptions)) (Resolver, error) {
var opts DecodeModelOptions
opts.Set(optFns...)
// Get the version of the partition file to determine what
// unmarshaling model to use.
modelDef := modelDefinition{}
if err := json.NewDecoder(r).Decode(&modelDef); err != nil {
return nil, newDecodeModelError("failed to decode endpoints model", err)
}
var version string
if b, ok := modelDef["version"]; ok {
version = string(b)
} else {
return nil, newDecodeModelError("endpoints version not found in model", nil)
}
if version == "3" {
return decodeV3Endpoints(modelDef, opts)
}
return nil, newDecodeModelError(
fmt.Sprintf("endpoints version %s, not supported", version), nil)
}
func decodeV3Endpoints(modelDef modelDefinition, opts DecodeModelOptions) (Resolver, error) {
b, ok := modelDef["partitions"]
if !ok {
return nil, newDecodeModelError("endpoints model missing partitions", nil)
}
ps := partitions{}
if err := json.Unmarshal(b, &ps); err != nil {
return nil, newDecodeModelError("failed to decode endpoints model", err)
}
if opts.SkipCustomizations {
return ps, nil
}
// Customization
for i := 0; i < len(ps); i++ {
p := &ps[i]
custAddEC2Metadata(p)
custAddS3DualStack(p)
custRmIotDataService(p)
}
return ps, nil
}
func custAddS3DualStack(p *partition) {
if p.ID != "aws" {
return
}
s, ok := p.Services["s3"]
if !ok {
return
}
s.Defaults.HasDualStack = boxedTrue
s.Defaults.DualStackHostname = "{service}.dualstack.{region}.{dnsSuffix}"
p.Services["s3"] = s
}
func custAddEC2Metadata(p *partition) {
p.Services["ec2metadata"] = service{
IsRegionalized: boxedFalse,
PartitionEndpoint: "aws-global",
Endpoints: endpoints{
"aws-global": endpoint{
Hostname: "169.254.169.254/latest",
Protocols: []string{"http"},
},
},
}
}
func custRmIotDataService(p *partition) {
delete(p.Services, "data.iot")
}
type decodeModelError struct {
awsError
}
func newDecodeModelError(msg string, err error) decodeModelError {
return decodeModelError{
awsError: awserr.New("DecodeEndpointsModelError", msg, err),
}
}

File diff suppressed because it is too large.

vendor/github.com/aws/aws-sdk-go/aws/endpoints/doc.go (generated, vendored, new file, 66 lines)
View File

@@ -0,0 +1,66 @@
// Package endpoints provides the types and functionality for defining regions
// and endpoints, as well as querying those definitions.
//
// The SDK's Regions and Endpoints metadata is code generated into the endpoints
// package, and is accessible via the DefaultResolver function. This function
// returns an endpoint Resolver that will search the metadata and build an
// associated endpoint if one is found. The default resolver will search all
// partitions known by the SDK, e.g. AWS Standard (aws), AWS China (aws-cn),
// and AWS GovCloud (US) (aws-us-gov).
//
// Enumerating Regions and Endpoint Metadata
//
// Casting the Resolver returned by DefaultResolver to an EnumPartitions interface
// will allow you to get access to the list of underlying Partitions with the
// Partitions method. This is helpful if you want to limit the SDK's endpoint
// resolving to a single partition, or enumerate regions, services, and endpoints
// in the partition.
//
// resolver := endpoints.DefaultResolver()
// partitions := resolver.(endpoints.EnumPartitions).Partitions()
//
// for _, p := range partitions {
// fmt.Println("Regions for", p.Name)
// for id, _ := range p.Regions() {
// fmt.Println("*", id)
// }
//
// fmt.Println("Services for", p.Name)
// for id, _ := range p.Services() {
// fmt.Println("*", id)
// }
// }
//
// Using Custom Endpoints
//
// The endpoints package also gives you the ability to use your own logic for how
// endpoints are resolved. This is a great way to define a custom endpoint
// for select services, without passing that logic down through your code.
//
// If a type implements the Resolver interface it can be used to resolve
// endpoints. To use this with the SDK's Session and Config, set the value
// of the type to the EndpointResolver field of aws.Config when initializing
// the session or service client.
//
// In addition the ResolverFunc is a wrapper for a func matching the signature
// of Resolver.EndpointFor, converting it to a type that satisfies the
// Resolver interface.
//
//
// myCustomResolver := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
// if service == endpoints.S3ServiceID {
// return endpoints.ResolvedEndpoint{
// URL: "s3.custom.endpoint.com",
// SigningRegion: "custom-signing-region",
// }, nil
// }
//
// return endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
// }
//
// sess := session.Must(session.NewSession(&aws.Config{
// Region: aws.String("us-west-2"),
// EndpointResolver: endpoints.ResolverFunc(myCustomResolver),
// }))
package endpoints
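For convenience, here is the custom-resolver pattern from the package comment above as a self-contained snippet. The custom URL and signing region are placeholders; everything else (S3ServiceID, ResolverFunc, ResolvedEndpoint, and the EndpointResolver config field) appears elsewhere in this diff.

package main

import (
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/endpoints"
	"github.com/aws/aws-sdk-go/aws/session"
)

func main() {
	// Pin S3 to a custom endpoint; defer to the default resolver otherwise.
	myCustomResolver := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
		if service == endpoints.S3ServiceID {
			return endpoints.ResolvedEndpoint{
				URL:           "s3.custom.endpoint.com", // placeholder
				SigningRegion: "custom-signing-region",  // placeholder
			}, nil
		}
		return endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
	}

	_ = session.Must(session.NewSession(&aws.Config{
		Region:           aws.String("us-west-2"),
		EndpointResolver: endpoints.ResolverFunc(myCustomResolver),
	}))
}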

View File

@@ -0,0 +1,369 @@
package endpoints
import (
"fmt"
"regexp"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// Options provide the configuration needed to direct how the
// endpoints will be resolved.
type Options struct {
// DisableSSL forces the endpoint to be resolved as HTTP
// instead of HTTPS if the service supports it.
DisableSSL bool
// Sets the resolver to resolve the endpoint as a dualstack endpoint
// for the service. If dualstack support for a service is not known and
// StrictMatching is not enabled a dualstack endpoint for the service will
// be returned. This endpoint may not be valid. If StrictMatching is
// enabled only services that are known to support dualstack will return
// dualstack endpoints.
UseDualStack bool
// Enables strict matching of services and regions resolved endpoints.
// If the partition doesn't enumerate the exact service and region an
// error will be returned. This option will prevent returning endpoints
// that look valid, but may not resolve to any real endpoint.
StrictMatching bool
}
// Set combines all of the option functions together.
func (o *Options) Set(optFns ...func(*Options)) {
for _, fn := range optFns {
fn(o)
}
}
// DisableSSLOption sets the DisableSSL options. Can be used as a functional
// option when resolving endpoints.
func DisableSSLOption(o *Options) {
o.DisableSSL = true
}
// UseDualStackOption sets the UseDualStack option. Can be used as a functional
// option when resolving endpoints.
func UseDualStackOption(o *Options) {
o.UseDualStack = true
}
// StrictMatchingOption sets the StrictMatching option. Can be used as a functional
// option when resolving endpoints.
func StrictMatchingOption(o *Options) {
o.StrictMatching = true
}
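// exampleEndpointFor is an illustrative sketch (not part of this diff) showing
// how the functional options above combine with a Resolver. The service and
// region values are arbitrary; DefaultResolver comes from the generated
// defaults elsewhere in this package.
func exampleEndpointFor() (ResolvedEndpoint, error) {
	// Resolve an S3 endpoint, requiring dualstack support and strict matching.
	return DefaultResolver().EndpointFor(
		"s3", "us-west-2",
		UseDualStackOption, StrictMatchingOption,
	)
}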
// A Resolver provides the interface for functionality to resolve endpoints.
// The built-in Partition and DefaultResolver return values satisfy this interface.
type Resolver interface {
EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error)
}
// ResolverFunc is a helper utility that wraps a function so it satisfies the
// Resolver interface. This is useful when you want to add additional endpoint
// resolving logic, or stub out specific endpoints with custom values.
type ResolverFunc func(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error)
// EndpointFor wraps the ResolverFunc function to satisfy the Resolver interface.
func (fn ResolverFunc) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
return fn(service, region, opts...)
}
var schemeRE = regexp.MustCompile("^([^:]+)://")
// AddScheme adds the HTTP or HTTPS scheme to an endpoint URL if there is no
// scheme. If disableSSL is true, HTTP is used instead of the default HTTPS.
//
// If disableSSL is set, it will only set the URL's scheme if the URL does not
// contain a scheme.
func AddScheme(endpoint string, disableSSL bool) string {
if !schemeRE.MatchString(endpoint) {
scheme := "https"
if disableSSL {
scheme = "http"
}
endpoint = fmt.Sprintf("%s://%s", scheme, endpoint)
}
return endpoint
}
// EnumPartitions provides a way to retrieve the underlying partitions that
// make up the SDK's default Resolver, or any resolver decoded from a model
// file.
//
// Use this interface with DefaultResolver and DecodeModel to get the list of
// Partitions.
type EnumPartitions interface {
Partitions() []Partition
}
// A Partition provides the ability to enumerate the partition's regions
// and services.
type Partition struct {
id string
p *partition
}
// ID returns the identifier of the partition.
func (p *Partition) ID() string { return p.id }
// EndpointFor attempts to resolve the endpoint based on service and region.
// See Options for information on configuring how the endpoint is resolved.
//
// If the service cannot be found in the metadata the UnknownServiceError
// error will be returned. This validation will occur regardless if
// StrictMatching is enabled.
//
// When resolving endpoints you can choose to enable StrictMatching. This will
// require the provided service and region to be known by the partition.
// If the endpoint cannot be strictly resolved an error will be returned. This
// mode is useful to ensure the endpoint resolved is valid. Without
// StrictMatching enabled the endpoint returned may look valid but may not work.
// StrictMatching requires the SDK to be updated if you want to take advantage
// of new region and service expansions.
//
// Errors that can be returned.
// * UnknownServiceError
// * UnknownEndpointError
func (p *Partition) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
return p.p.EndpointFor(service, region, opts...)
}
// Regions returns a map of Regions indexed by their ID. This is useful for
// enumerating over the regions in a partition.
func (p *Partition) Regions() map[string]Region {
rs := map[string]Region{}
for id := range p.p.Regions {
rs[id] = Region{
id: id,
p: p.p,
}
}
return rs
}
// Services returns a map of Service indexed by their ID. This is useful for
// enumerating over the services in a partition.
func (p *Partition) Services() map[string]Service {
ss := map[string]Service{}
for id := range p.p.Services {
ss[id] = Service{
id: id,
p: p.p,
}
}
return ss
}
// A Region provides information about a region, and the ability to resolve an
// endpoint from the context of a region, given a service.
type Region struct {
id, desc string
p *partition
}
// ID returns the region's identifier.
func (r *Region) ID() string { return r.id }
// ResolveEndpoint resolves an endpoint from the context of the region given
// a service. See Partition.EndpointFor for usage and errors that can be returned.
func (r *Region) ResolveEndpoint(service string, opts ...func(*Options)) (ResolvedEndpoint, error) {
return r.p.EndpointFor(service, r.id, opts...)
}
// Services returns a list of all services that are known to be in this region.
func (r *Region) Services() map[string]Service {
ss := map[string]Service{}
for id, s := range r.p.Services {
if _, ok := s.Endpoints[r.id]; ok {
ss[id] = Service{
id: id,
p: r.p,
}
}
}
return ss
}
// A Service provides information about a service, and the ability to resolve an
// endpoint from the context of a service, given a region.
type Service struct {
id string
p *partition
}
// ID returns the identifier for the service.
func (s *Service) ID() string { return s.id }
// ResolveEndpoint resolves an endpoint from the context of a service given
// a region. See Partition.EndpointFor for usage and errors that can be returned.
func (s *Service) ResolveEndpoint(region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
return s.p.EndpointFor(s.id, region, opts...)
}
// Endpoints returns a map of Endpoints indexed by their ID for all known
// endpoints for a service.
func (s *Service) Endpoints() map[string]Endpoint {
es := map[string]Endpoint{}
for id := range s.p.Services[s.id].Endpoints {
es[id] = Endpoint{
id: id,
serviceID: s.id,
p: s.p,
}
}
return es
}
// An Endpoint provides information about endpoints, and provides the ability
// to resolve that endpoint for the service, and the region the endpoint
// represents.
type Endpoint struct {
id string
serviceID string
p *partition
}
// ID returns the identifier for an endpoint.
func (e *Endpoint) ID() string { return e.id }
// ServiceID returns the identifier the endpoint belongs to.
func (e *Endpoint) ServiceID() string { return e.serviceID }
// ResolveEndpoint resolves an endpoint from the context of a service and
// region the endpoint represents. See Partition.EndpointFor for usage and
// errors that can be returned.
func (e *Endpoint) ResolveEndpoint(opts ...func(*Options)) (ResolvedEndpoint, error) {
return e.p.EndpointFor(e.serviceID, e.id, opts...)
}
// A ResolvedEndpoint is an endpoint that has been resolved based on a partition
// service, and region.
type ResolvedEndpoint struct {
// The endpoint URL
URL string
// The region that should be used for signing requests.
SigningRegion string
// The service name that should be used for signing requests.
SigningName string
// The signing method that should be used for signing requests.
SigningMethod string
}
// So that the Error interface type can be included as an anonymous field
// in the requestError struct and not conflict with the error.Error() method.
type awsError awserr.Error
// An EndpointNotFoundError is returned when in StrictMatching mode, and the
// endpoint for the service and region cannot be found in any of the partitions.
type EndpointNotFoundError struct {
awsError
Partition string
Service string
Region string
}
//// NewEndpointNotFoundError builds and returns NewEndpointNotFoundError.
//func NewEndpointNotFoundError(p, s, r string) EndpointNotFoundError {
// return EndpointNotFoundError{
// awsError: awserr.New("EndpointNotFoundError", "unable to find endpoint", nil),
// Partition: p,
// Service: s,
// Region: r,
// }
//}
//
//// Error returns string representation of the error.
//func (e EndpointNotFoundError) Error() string {
// extra := fmt.Sprintf("partition: %q, service: %q, region: %q",
// e.Partition, e.Service, e.Region)
// return awserr.SprintError(e.Code(), e.Message(), extra, e.OrigErr())
//}
//
//// String returns the string representation of the error.
//func (e EndpointNotFoundError) String() string {
// return e.Error()
//}
// An UnknownServiceError is returned when the service does not resolve to an
// endpoint. Includes a list of all known services for the partition. Returned
// when a partition does not support the service.
type UnknownServiceError struct {
awsError
Partition string
Service string
Known []string
}
// NewUnknownServiceError builds and returns UnknownServiceError.
func NewUnknownServiceError(p, s string, known []string) UnknownServiceError {
return UnknownServiceError{
awsError: awserr.New("UnknownServiceError",
"could not resolve endpoint for unknown service", nil),
Partition: p,
Service: s,
Known: known,
}
}
// Error returns the string representation of the error.
func (e UnknownServiceError) Error() string {
extra := fmt.Sprintf("partition: %q, service: %q",
e.Partition, e.Service)
if len(e.Known) > 0 {
extra += fmt.Sprintf(", known: %v", e.Known)
}
return awserr.SprintError(e.Code(), e.Message(), extra, e.OrigErr())
}
// String returns the string representation of the error.
func (e UnknownServiceError) String() string {
return e.Error()
}
// An UnknownEndpointError is returned when in StrictMatching mode and the
// service is valid, but the region does not resolve to an endpoint. Includes
// a list of all known endpoints for the service.
type UnknownEndpointError struct {
awsError
Partition string
Service string
Region string
Known []string
}
// NewUnknownEndpointError builds and returns UnknownEndpointError.
func NewUnknownEndpointError(p, s, r string, known []string) UnknownEndpointError {
return UnknownEndpointError{
awsError: awserr.New("UnknownEndpointError",
"could not resolve endpoint", nil),
Partition: p,
Service: s,
Region: r,
Known: known,
}
}
// Error returns the string representation of the error.
func (e UnknownEndpointError) Error() string {
extra := fmt.Sprintf("partition: %q, service: %q, region: %q",
e.Partition, e.Service, e.Region)
if len(e.Known) > 0 {
extra += fmt.Sprintf(", known: %v", e.Known)
}
return awserr.SprintError(e.Code(), e.Message(), extra, e.OrigErr())
}
// String returns the string representation of the error.
func (e UnknownEndpointError) String() string {
return e.Error()
}

View File

@@ -0,0 +1,301 @@
package endpoints
import (
"fmt"
"regexp"
"strconv"
"strings"
)
type partitions []partition
func (ps partitions) EndpointFor(service, region string, opts ...func(*Options)) (ResolvedEndpoint, error) {
var opt Options
opt.Set(opts...)
for i := 0; i < len(ps); i++ {
if !ps[i].canResolveEndpoint(service, region, opt.StrictMatching) {
continue
}
return ps[i].EndpointFor(service, region, opts...)
}
// If loose matching, fall back to the first partition's format to use
// when resolving the endpoint.
if !opt.StrictMatching && len(ps) > 0 {
return ps[0].EndpointFor(service, region, opts...)
}
return ResolvedEndpoint{}, NewUnknownEndpointError("all partitions", service, region, []string{})
}
// Partitions satisfies the EnumPartitions interface and returns a list
// of Partitions representing each partition represented in the SDK's
// endpoints model.
func (ps partitions) Partitions() []Partition {
parts := make([]Partition, 0, len(ps))
for i := 0; i < len(ps); i++ {
parts = append(parts, ps[i].Partition())
}
return parts
}
type partition struct {
ID string `json:"partition"`
Name string `json:"partitionName"`
DNSSuffix string `json:"dnsSuffix"`
RegionRegex regionRegex `json:"regionRegex"`
Defaults endpoint `json:"defaults"`
Regions regions `json:"regions"`
Services services `json:"services"`
}
func (p partition) Partition() Partition {
return Partition{
id: p.ID,
p: &p,
}
}
func (p partition) canResolveEndpoint(service, region string, strictMatch bool) bool {
s, hasService := p.Services[service]
_, hasEndpoint := s.Endpoints[region]
if hasEndpoint && hasService {
return true
}
if strictMatch {
return false
}
return p.RegionRegex.MatchString(region)
}
func (p partition) EndpointFor(service, region string, opts ...func(*Options)) (resolved ResolvedEndpoint, err error) {
var opt Options
opt.Set(opts...)
s, hasService := p.Services[service]
if !hasService {
return resolved, NewUnknownServiceError(p.ID, service, serviceList(p.Services))
}
e, hasEndpoint := s.endpointForRegion(region)
if !hasEndpoint && opt.StrictMatching {
return resolved, NewUnknownEndpointError(p.ID, service, region, endpointList(s.Endpoints))
}
defs := []endpoint{p.Defaults, s.Defaults}
return e.resolve(service, region, p.DNSSuffix, defs, opt), nil
}
func serviceList(ss services) []string {
list := make([]string, 0, len(ss))
for k := range ss {
list = append(list, k)
}
return list
}
func endpointList(es endpoints) []string {
list := make([]string, 0, len(es))
for k := range es {
list = append(list, k)
}
return list
}
type regionRegex struct {
*regexp.Regexp
}
func (rr *regionRegex) UnmarshalJSON(b []byte) (err error) {
// Strip leading and trailing quotes
regex, err := strconv.Unquote(string(b))
if err != nil {
return fmt.Errorf("unable to strip quotes from regex, %v", err)
}
rr.Regexp, err = regexp.Compile(regex)
if err != nil {
return fmt.Errorf("unable to unmarshal region regex, %v", err)
}
return nil
}
type regions map[string]region
type region struct {
Description string `json:"description"`
}
type services map[string]service
type service struct {
PartitionEndpoint string `json:"partitionEndpoint"`
IsRegionalized boxedBool `json:"isRegionalized,omitempty"`
Defaults endpoint `json:"defaults"`
Endpoints endpoints `json:"endpoints"`
}
func (s *service) endpointForRegion(region string) (endpoint, bool) {
if s.IsRegionalized == boxedFalse {
return s.Endpoints[s.PartitionEndpoint], region == s.PartitionEndpoint
}
if e, ok := s.Endpoints[region]; ok {
return e, true
}
// Unable to find any matching endpoint, return
// blank that will be used for generic endpoint creation.
return endpoint{}, false
}
type endpoints map[string]endpoint
type endpoint struct {
Hostname string `json:"hostname"`
Protocols []string `json:"protocols"`
CredentialScope credentialScope `json:"credentialScope"`
// Custom fields not modeled
HasDualStack boxedBool `json:"-"`
DualStackHostname string `json:"-"`
// Signature Version not used
SignatureVersions []string `json:"signatureVersions"`
// SSLCommonName not used.
SSLCommonName string `json:"sslCommonName"`
}
const (
defaultProtocol = "https"
defaultSigner = "v4"
)
var (
protocolPriority = []string{"https", "http"}
signerPriority = []string{"v4", "v2"}
)
func getByPriority(s []string, p []string, def string) string {
if len(s) == 0 {
return def
}
for i := 0; i < len(p); i++ {
for j := 0; j < len(s); j++ {
if s[j] == p[i] {
return s[j]
}
}
}
return s[0]
}
func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, opts Options) ResolvedEndpoint {
var merged endpoint
for _, def := range defs {
merged.mergeIn(def)
}
merged.mergeIn(e)
e = merged
hostname := e.Hostname
// Offset the hostname for dualstack if enabled
if opts.UseDualStack && e.HasDualStack == boxedTrue {
hostname = e.DualStackHostname
}
u := strings.Replace(hostname, "{service}", service, 1)
u = strings.Replace(u, "{region}", region, 1)
u = strings.Replace(u, "{dnsSuffix}", dnsSuffix, 1)
scheme := getEndpointScheme(e.Protocols, opts.DisableSSL)
u = fmt.Sprintf("%s://%s", scheme, u)
signingRegion := e.CredentialScope.Region
if len(signingRegion) == 0 {
signingRegion = region
}
signingName := e.CredentialScope.Service
if len(signingName) == 0 {
signingName = service
}
return ResolvedEndpoint{
URL: u,
SigningRegion: signingRegion,
SigningName: signingName,
SigningMethod: getByPriority(e.SignatureVersions, signerPriority, defaultSigner),
}
}
func getEndpointScheme(protocols []string, disableSSL bool) string {
if disableSSL {
return "http"
}
return getByPriority(protocols, protocolPriority, defaultProtocol)
}
func (e *endpoint) mergeIn(other endpoint) {
if len(other.Hostname) > 0 {
e.Hostname = other.Hostname
}
if len(other.Protocols) > 0 {
e.Protocols = other.Protocols
}
if len(other.SignatureVersions) > 0 {
e.SignatureVersions = other.SignatureVersions
}
if len(other.CredentialScope.Region) > 0 {
e.CredentialScope.Region = other.CredentialScope.Region
}
if len(other.CredentialScope.Service) > 0 {
e.CredentialScope.Service = other.CredentialScope.Service
}
if len(other.SSLCommonName) > 0 {
e.SSLCommonName = other.SSLCommonName
}
if other.HasDualStack != boxedBoolUnset {
e.HasDualStack = other.HasDualStack
}
if len(other.DualStackHostname) > 0 {
e.DualStackHostname = other.DualStackHostname
}
}
type credentialScope struct {
Region string `json:"region"`
Service string `json:"service"`
}
type boxedBool int
func (b *boxedBool) UnmarshalJSON(buf []byte) error {
v, err := strconv.ParseBool(string(buf))
if err != nil {
return err
}
if v {
*b = boxedTrue
} else {
*b = boxedFalse
}
return nil
}
const (
boxedBoolUnset boxedBool = iota
boxedFalse
boxedTrue
)

View File

@@ -0,0 +1,334 @@
// +build codegen
package endpoints
import (
"fmt"
"io"
"reflect"
"strings"
"text/template"
"unicode"
)
// A CodeGenOptions are the options for code generating the endpoints into
// Go code from the endpoints model definition.
type CodeGenOptions struct {
// Options for how the model will be decoded.
DecodeModelOptions DecodeModelOptions
}
// Set combines all of the option functions together
func (d *CodeGenOptions) Set(optFns ...func(*CodeGenOptions)) {
for _, fn := range optFns {
fn(d)
}
}
// CodeGenModel, given an endpoints model file, will decode it and attempt to
// generate Go code from the model definition. An error will be returned if
// the code is unable to be generated, or decoded.
func CodeGenModel(modelFile io.Reader, outFile io.Writer, optFns ...func(*CodeGenOptions)) error {
var opts CodeGenOptions
opts.Set(optFns...)
resolver, err := DecodeModel(modelFile, func(d *DecodeModelOptions) {
*d = opts.DecodeModelOptions
})
if err != nil {
return err
}
tmpl := template.Must(template.New("tmpl").Funcs(funcMap).Parse(v3Tmpl))
if err := tmpl.ExecuteTemplate(outFile, "defaults", resolver); err != nil {
return fmt.Errorf("failed to execute template, %v", err)
}
return nil
}
func toSymbol(v string) string {
out := []rune{}
for _, c := range strings.Title(v) {
if !(unicode.IsNumber(c) || unicode.IsLetter(c)) {
continue
}
out = append(out, c)
}
return string(out)
}
func quoteString(v string) string {
return fmt.Sprintf("%q", v)
}
func regionConstName(p, r string) string {
return toSymbol(p) + toSymbol(r)
}
func partitionGetter(id string) string {
return fmt.Sprintf("%sPartition", toSymbol(id))
}
func partitionVarName(id string) string {
return fmt.Sprintf("%sPartition", strings.ToLower(toSymbol(id)))
}
func listPartitionNames(ps partitions) string {
names := []string{}
switch len(ps) {
case 1:
return ps[0].Name
case 2:
return fmt.Sprintf("%s and %s", ps[0].Name, ps[1].Name)
default:
for i, p := range ps {
if i == len(ps)-1 {
names = append(names, "and "+p.Name)
} else {
names = append(names, p.Name)
}
}
return strings.Join(names, ", ")
}
}
func boxedBoolIfSet(msg string, v boxedBool) string {
switch v {
case boxedTrue:
return fmt.Sprintf(msg, "boxedTrue")
case boxedFalse:
return fmt.Sprintf(msg, "boxedFalse")
default:
return ""
}
}
func stringIfSet(msg, v string) string {
if len(v) == 0 {
return ""
}
return fmt.Sprintf(msg, v)
}
func stringSliceIfSet(msg string, vs []string) string {
if len(vs) == 0 {
return ""
}
names := []string{}
for _, v := range vs {
names = append(names, `"`+v+`"`)
}
return fmt.Sprintf(msg, strings.Join(names, ","))
}
func endpointIsSet(v endpoint) bool {
return !reflect.DeepEqual(v, endpoint{})
}
func serviceSet(ps partitions) map[string]struct{} {
set := map[string]struct{}{}
for _, p := range ps {
for id := range p.Services {
set[id] = struct{}{}
}
}
return set
}
var funcMap = template.FuncMap{
"ToSymbol": toSymbol,
"QuoteString": quoteString,
"RegionConst": regionConstName,
"PartitionGetter": partitionGetter,
"PartitionVarName": partitionVarName,
"ListPartitionNames": listPartitionNames,
"BoxedBoolIfSet": boxedBoolIfSet,
"StringIfSet": stringIfSet,
"StringSliceIfSet": stringSliceIfSet,
"EndpointIsSet": endpointIsSet,
"ServicesSet": serviceSet,
}
const v3Tmpl = `
{{ define "defaults" -}}
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
package endpoints
import (
"regexp"
)
{{ template "partition consts" . }}
{{ range $_, $partition := . }}
{{ template "partition region consts" $partition }}
{{ end }}
{{ template "service consts" . }}
{{ template "endpoint resolvers" . }}
{{- end }}
{{ define "partition consts" }}
// Partition identifiers
const (
{{ range $_, $p := . -}}
{{ ToSymbol $p.ID }}PartitionID = {{ QuoteString $p.ID }} // {{ $p.Name }} partition.
{{ end -}}
)
{{- end }}
{{ define "partition region consts" }}
// {{ .Name }} partition's regions.
const (
{{ range $id, $region := .Regions -}}
{{ ToSymbol $id }}RegionID = {{ QuoteString $id }} // {{ $region.Description }}.
{{ end -}}
)
{{- end }}
{{ define "service consts" }}
// Service identifiers
const (
{{ $serviceSet := ServicesSet . -}}
{{ range $id, $_ := $serviceSet -}}
{{ ToSymbol $id }}ServiceID = {{ QuoteString $id }} // {{ ToSymbol $id }}.
{{ end -}}
)
{{- end }}
{{ define "endpoint resolvers" }}
// DefaultResolver returns an Endpoint resolver that will be able
// to resolve endpoints for: {{ ListPartitionNames . }}.
//
// Casting the return value of this func to an EnumPartitions will
// allow you to get a list of the partitions in the order the endpoints
// will be resolved in.
//
// resolver := endpoints.DefaultResolver()
// partitions := resolver.(endpoints.EnumPartitions).Partitions()
// for _, p := range partitions {
// // ... inspect partitions
// }
func DefaultResolver() Resolver {
return defaultPartitions
}
var defaultPartitions = partitions{
{{ range $_, $partition := . -}}
{{ PartitionVarName $partition.ID }},
{{ end }}
}
{{ range $_, $partition := . -}}
{{ $name := PartitionGetter $partition.ID -}}
// {{ $name }} returns the Resolver for {{ $partition.Name }}.
func {{ $name }}() Partition {
return {{ PartitionVarName $partition.ID }}.Partition()
}
var {{ PartitionVarName $partition.ID }} = {{ template "gocode Partition" $partition }}
{{ end }}
{{ end }}
{{ define "default partitions" }}
func DefaultPartitions() []Partition {
return []partition{
{{ range $_, $partition := . -}}
// {{ ToSymbol $partition.ID}}Partition(),
{{ end }}
}
}
{{ end }}
{{ define "gocode Partition" -}}
partition{
{{ StringIfSet "ID: %q,\n" .ID -}}
{{ StringIfSet "Name: %q,\n" .Name -}}
{{ StringIfSet "DNSSuffix: %q,\n" .DNSSuffix -}}
RegionRegex: {{ template "gocode RegionRegex" .RegionRegex }},
{{ if EndpointIsSet .Defaults -}}
Defaults: {{ template "gocode Endpoint" .Defaults }},
{{- end }}
Regions: {{ template "gocode Regions" .Regions }},
Services: {{ template "gocode Services" .Services }},
}
{{- end }}
{{ define "gocode RegionRegex" -}}
regionRegex{
Regexp: func() *regexp.Regexp{
reg, _ := regexp.Compile({{ QuoteString .Regexp.String }})
return reg
}(),
}
{{- end }}
{{ define "gocode Regions" -}}
regions{
{{ range $id, $region := . -}}
"{{ $id }}": {{ template "gocode Region" $region }},
{{ end -}}
}
{{- end }}
{{ define "gocode Region" -}}
region{
{{ StringIfSet "Description: %q,\n" .Description -}}
}
{{- end }}
{{ define "gocode Services" -}}
services{
{{ range $id, $service := . -}}
"{{ $id }}": {{ template "gocode Service" $service }},
{{ end }}
}
{{- end }}
{{ define "gocode Service" -}}
service{
{{ StringIfSet "PartitionEndpoint: %q,\n" .PartitionEndpoint -}}
{{ BoxedBoolIfSet "IsRegionalized: %s,\n" .IsRegionalized -}}
{{ if EndpointIsSet .Defaults -}}
Defaults: {{ template "gocode Endpoint" .Defaults -}},
{{- end }}
{{ if .Endpoints -}}
Endpoints: {{ template "gocode Endpoints" .Endpoints }},
{{- end }}
}
{{- end }}
{{ define "gocode Endpoints" -}}
endpoints{
{{ range $id, $endpoint := . -}}
"{{ $id }}": {{ template "gocode Endpoint" $endpoint }},
{{ end }}
}
{{- end }}
{{ define "gocode Endpoint" -}}
endpoint{
{{ StringIfSet "Hostname: %q,\n" .Hostname -}}
{{ StringIfSet "SSLCommonName: %q,\n" .SSLCommonName -}}
{{ StringSliceIfSet "Protocols: []string{%s},\n" .Protocols -}}
{{ StringSliceIfSet "SignatureVersions: []string{%s},\n" .SignatureVersions -}}
{{ if or .CredentialScope.Region .CredentialScope.Service -}}
CredentialScope: credentialScope{
{{ StringIfSet "Region: %q,\n" .CredentialScope.Region -}}
{{ StringIfSet "Service: %q,\n" .CredentialScope.Service -}}
},
{{- end }}
{{ BoxedBoolIfSet "HasDualStack: %s,\n" .HasDualStack -}}
{{ StringIfSet "DualStackHostname: %q,\n" .DualStackHostname -}}
}
{{- end }}
`
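A hedged sketch of how this generator might be driven. The input and output file names are assumptions; the only confirmed API is CodeGenModel as declared above, and the `codegen` build tag must be set (e.g. `go run -tags codegen`).

```go
// +build codegen

package main

import (
	"log"
	"os"

	"github.com/aws/aws-sdk-go/aws/endpoints"
)

func main() {
	// endpoints.json is the assumed model file; defaults.go the assumed output.
	in, err := os.Open("endpoints.json")
	if err != nil {
		log.Fatal(err)
	}
	defer in.Close()

	out, err := os.Create("defaults.go")
	if err != nil {
		log.Fatal(err)
	}
	defer out.Close()

	// Decode the model and render it through the v3Tmpl template shown above.
	if err := endpoints.CodeGenModel(in, out); err != nil {
		log.Fatal(err)
	}
}
```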

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"io"
"net"
"net/http"
"net/url"
"reflect"
@@ -245,7 +246,82 @@ func (r *Request) ResetBody() {
}
r.safeBody = newOffsetReader(r.Body, r.BodyStart)
r.HTTPRequest.Body = r.safeBody
// Go 1.8 tightened and clarified the rules code needs to use when building
// requests with the http package. Go 1.8 removed the automatic detection
// of whether the Request.Body was empty or actually had bytes in it. The
// SDK always sets the Request.Body even if it is empty and should not
// actually be sent, which is incorrect.
//
// Go 1.8 did add an http.NoBody value that the SDK can use to tell the http
// client that the request really should be sent without a body. The
// Request.Body cannot simply be set to nil, which would be preferable,
// because the field is exported and a nil value could introduce nil pointer
// dereferences for users of the SDK who read that field.
//
// Related golang/go#18257
l, err := computeBodyLength(r.Body)
if err != nil {
r.Error = awserr.New("SerializationError", "failed to compute request body size", err)
return
}
if l == 0 {
r.HTTPRequest.Body = noBodyReader
} else if l > 0 {
r.HTTPRequest.Body = r.safeBody
} else {
// Hack to prevent sending bodies for methods where the body
// should be ignored by the server. Sending bodies on these
// methods without an associated ContentLength will cause the
// request to socket timeout because the server does not handle
// Transfer-Encoding: chunked bodies for these methods.
//
// This would only happen if an aws.ReaderSeekerCloser was used with
// an io.Reader that was not also an io.Seeker.
switch r.Operation.HTTPMethod {
case "GET", "HEAD", "DELETE":
r.HTTPRequest.Body = noBodyReader
default:
r.HTTPRequest.Body = r.safeBody
}
}
}
// Attempts to compute the length of the body of the reader using the
// io.Seeker interface. If the value is not seekable because it is a
// ReaderSeekerCloser without an underlying Seeker, -1 will be returned.
// If no error occurs the length of the body will be returned.
func computeBodyLength(r io.ReadSeeker) (int64, error) {
seekable := true
// Determine if the seeker is actually seekable. ReaderSeekerCloser
// hides the fact that an io.Reader might not actually be seekable.
switch v := r.(type) {
case aws.ReaderSeekerCloser:
seekable = v.IsSeeker()
case *aws.ReaderSeekerCloser:
seekable = v.IsSeeker()
}
if !seekable {
return -1, nil
}
curOffset, err := r.Seek(0, 1)
if err != nil {
return 0, err
}
endOffset, err := r.Seek(0, 2)
if err != nil {
return 0, err
}
_, err = r.Seek(curOffset, 0)
if err != nil {
return 0, err
}
return endOffset - curOffset, nil
}
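As a small self-contained illustration of the seek dance above (current offset, end offset, then restore), assuming nothing beyond the standard library. The SDK uses numeric whence values for pre-Go 1.7 compatibility; the sketch uses the io.Seek* constants for readability.

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

// remaining reports how many bytes are left between the reader's current
// offset and the end, restoring the original offset afterwards.
func remaining(r io.ReadSeeker) (int64, error) {
	cur, err := r.Seek(0, io.SeekCurrent)
	if err != nil {
		return 0, err
	}
	end, err := r.Seek(0, io.SeekEnd)
	if err != nil {
		return 0, err
	}
	if _, err := r.Seek(cur, io.SeekStart); err != nil {
		return 0, err
	}
	return end - cur, nil
}

func main() {
	r := strings.NewReader("hello world")
	n, _ := remaining(r)
	fmt.Println(n) // 11
}
```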
// GetBody will return an io.ReadSeeker of the Request's underlying
@@ -297,7 +373,7 @@ func (r *Request) Send() error {
r.Handlers.Send.Run(r)
if r.Error != nil {
if strings.Contains(r.Error.Error(), "net/http: request canceled") {
if !shouldRetryCancel(r) {
return r.Error
}
@@ -364,3 +440,26 @@ func AddToUserAgent(r *Request, s string) {
}
r.HTTPRequest.Header.Set("User-Agent", s)
}
func shouldRetryCancel(r *Request) bool {
awsErr, ok := r.Error.(awserr.Error)
timeoutErr := false
errStr := r.Error.Error()
if ok {
err := awsErr.OrigErr()
netErr, netOK := err.(net.Error)
timeoutErr = netOK && netErr.Temporary()
if urlErr, ok := err.(*url.Error); !timeoutErr && ok {
errStr = urlErr.Err.Error()
}
}
// There can be two types of canceled errors here.
// The first being a net.Error and the other being an error.
// If the request was timed out, we want to continue the retry
// process. Otherwise, return the canceled error.
return timeoutErr ||
(errStr != "net/http: request canceled" &&
errStr != "net/http: request canceled while waiting for connection")
}
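A rough, standalone sketch of the distinction drawn above, under the assumption that a temporary net.Error wrapped in a *url.Error should be retried while a plain cancellation should not. The error strings are the ones matched by shouldRetryCancel; this is not the SDK's implementation, just the idea it encodes.

```go
package main

import (
	"errors"
	"fmt"
	"net"
	"net/url"
)

// fakeTimeout is a stand-in net.Error that reports itself as temporary.
type fakeTimeout struct{}

func (fakeTimeout) Error() string   { return "i/o timeout" }
func (fakeTimeout) Timeout() bool   { return true }
func (fakeTimeout) Temporary() bool { return true }

func retryable(err error) bool {
	errStr := err.Error()
	if urlErr, ok := err.(*url.Error); ok {
		if netErr, ok := urlErr.Err.(net.Error); ok && netErr.Temporary() {
			return true
		}
		errStr = urlErr.Err.Error()
	}
	return errStr != "net/http: request canceled" &&
		errStr != "net/http: request canceled while waiting for connection"
}

func main() {
	fmt.Println(retryable(&url.Error{Op: "Post", Err: fakeTimeout{}}))                           // true
	fmt.Println(retryable(&url.Error{Op: "Post", Err: errors.New("net/http: request canceled")})) // false
}
```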

View File

@@ -0,0 +1,21 @@
// +build !go1.8
package request
import "io"
// NoBody is an io.ReadCloser with no bytes. Read always returns EOF
// and Close always returns nil. It can be used in an outgoing client
// request to explicitly signal that a request has zero bytes.
// An alternative, however, is to simply set Request.Body to nil.
//
// Copy of Go 1.8 NoBody type from net/http/http.go
type noBody struct{}
func (noBody) Read([]byte) (int, error) { return 0, io.EOF }
func (noBody) Close() error { return nil }
func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil }
// noBodyReader is an empty reader that will trigger the Go HTTP client to
// not include a body in the HTTP request.
var noBodyReader = noBody{}

View File

@@ -0,0 +1,9 @@
// +build go1.8
package request
import "net/http"
// noBodyReader is the http.NoBody reader, instructing the Go HTTP client to
// not include a body in the HTTP request.
var noBodyReader = http.NoBody
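For reference, a tiny Go 1.8+ sketch of what this alias buys: an explicit "no body" marker that keeps Request.Body non-nil, mirroring what the SDK does with noBodyReader.

```go
package main

import (
	"fmt"
	"net/http"
)

func main() {
	req, _ := http.NewRequest("GET", "https://example.com", nil)
	// Signal "really no body" without leaving the exported Body field nil.
	req.Body = http.NoBody
	fmt.Println(req.Body == http.NoBody) // true
}
```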

View File

@@ -10,8 +10,8 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/endpoints"
)
// A Session provides a central location to create service clients from and
@@ -44,7 +44,7 @@ type Session struct {
// shared config, and shared credentials will be taken from the shared
// credentials file.
//
// Deprecated: Use NewSession functiions to create sessions instead. NewSession
// Deprecated: Use NewSession functions to create sessions instead. NewSession
// has the same functionality as New except an error can be returned when the
// func is called instead of waiting to receive an error until a request is made.
func New(cfgs ...*aws.Config) *Session {
@@ -222,6 +222,11 @@ func oldNewSession(cfgs ...*aws.Config) *Session {
// Apply the passed in configs so the configuration can be applied to the
// default credential chain
cfg.MergeIn(cfgs...)
if cfg.EndpointResolver == nil {
// An endpoint resolver is required for a session to be able to provide
// endpoints for service client configurations.
cfg.EndpointResolver = endpoints.DefaultResolver()
}
cfg.Credentials = defaults.CredChain(cfg, handlers)
// Reapply any passed in configs to override credentials if set
@@ -375,19 +380,39 @@ func (s *Session) Copy(cfgs ...*aws.Config) *Session {
// configure the service client instances. Passing the Session to the service
// client's constructor (New) will use this method to configure the client.
func (s *Session) ClientConfig(serviceName string, cfgs ...*aws.Config) client.Config {
// For backwards compatibility the error is swallowed if the user calls
// ClientConfig directly. All SDK services will use clientConfigWithErr.
cfg, _ := s.clientConfigWithErr(serviceName, cfgs...)
return cfg
}
func (s *Session) clientConfigWithErr(serviceName string, cfgs ...*aws.Config) (client.Config, error) {
s = s.Copy(cfgs...)
endpoint, signingRegion := endpoints.NormalizeEndpoint(
aws.StringValue(s.Config.Endpoint),
serviceName,
aws.StringValue(s.Config.Region),
aws.BoolValue(s.Config.DisableSSL),
aws.BoolValue(s.Config.UseDualStack),
)
var resolved endpoints.ResolvedEndpoint
var err error
region := aws.StringValue(s.Config.Region)
if endpoint := aws.StringValue(s.Config.Endpoint); len(endpoint) != 0 {
resolved.URL = endpoints.AddScheme(endpoint, aws.BoolValue(s.Config.DisableSSL))
resolved.SigningRegion = region
} else {
resolved, err = s.Config.EndpointResolver.EndpointFor(
serviceName, region,
func(opt *endpoints.Options) {
opt.DisableSSL = aws.BoolValue(s.Config.DisableSSL)
opt.UseDualStack = aws.BoolValue(s.Config.UseDualStack)
},
)
}
return client.Config{
Config: s.Config,
Handlers: s.Handlers,
Endpoint: endpoint,
SigningRegion: signingRegion,
}
Endpoint: resolved.URL,
SigningRegion: resolved.SigningRegion,
SigningName: resolved.SigningName,
}, err
}
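A hedged usage sketch of the resolver the session now consults; the service and region strings are illustrative, and only the EndpointFor call shape shown in the diff above is assumed.

```go
package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go/aws/endpoints"
)

func main() {
	resolver := endpoints.DefaultResolver()

	resolved, err := resolver.EndpointFor(
		"dynamodb", "us-west-2",
		func(opt *endpoints.Options) {
			opt.DisableSSL = false
			opt.UseDualStack = false
		},
	)
	if err != nil {
		fmt.Println("resolve failed:", err)
		return
	}

	// The session feeds these values into client.Config as shown above.
	fmt.Println(resolved.URL, resolved.SigningRegion, resolved.SigningName)
}
```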

View File

@@ -1,24 +0,0 @@
// +build !go1.5
package v4
import (
"net/url"
"strings"
)
func getURIPath(u *url.URL) string {
var uri string
if len(u.Opaque) > 0 {
uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/")
} else {
uri = u.Path
}
if len(uri) == 0 {
uri = "/"
}
return uri
}

View File

@@ -42,6 +42,14 @@
// the URL.Opaque or URL.RawPath. The SDK will use URL.Opaque first and then
// call URL.EscapedPath() if Opaque is not set.
//
// If signing a request intended for an HTTP/2 server, and you're using Go 1.6.2
// through 1.7.4, you should use URL.RawPath as the pre-escaped form of the
// request URL. https://github.com/golang/go/issues/16847 points to a bug in
// Go pre-1.8 that fails to make HTTP/2 requests using an absolute URL in the
// HTTP message. URL.Opaque generally will force Go to make requests with an
// absolute URL. URL.RawPath does not do this, but RawPath must be a valid
// escaping of Path or url.EscapedPath will ignore the RawPath escaping.
//
// Test `TestStandaloneSign` provides a complete example of using the signer
// outside of the SDK and pre-escaping the URI path.
package v4
@@ -171,6 +179,16 @@ type Signer struct {
// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
DisableURIPathEscaping bool
// Disables the automatic setting of the HTTP request's Body field with the
// io.ReadSeeker passed in to the signer. This is useful if you're using a
// custom wrapper around the body for the io.ReadSeeker and want to preserve
// the Body value on the Request.Body.
//
// This does run the risk of signing a request with a body that will not be
// sent in the request. You need to ensure that the underlying data of the
// Body values are the same.
DisableRequestBodyOverwrite bool
// currentTimeFn returns the time value which represents the current time.
// This value should only be used for testing. If it is nil the default
// time.Now will be used.
@@ -321,7 +339,7 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi
// If the request is not presigned the body should be attached to it. This
// prevents the confusion of wanting to send a signed request without
// the body the request was signed for attached.
if !ctx.isPresign {
if !(v4.DisableRequestBodyOverwrite || ctx.isPresign) {
var reader io.ReadCloser
if body != nil {
var ok bool
@@ -416,6 +434,10 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
// S3 service should not have any escaping applied
v4.DisableURIPathEscaping = true
}
// Prevents setting the HTTPRequest's Body, since the Body could be
// wrapped in a custom io.Closer that we do not want to be stomped
// on by the signer.
v4.DisableRequestBodyOverwrite = true
})
signingTime := req.Time

View File

@@ -5,7 +5,13 @@ import (
"sync"
)
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser
// ReadSeekCloser wraps an io.Reader, returning a ReaderSeekerCloser. It should
// only be used with an io.Reader that is also an io.Seeker. Wrapping a reader
// that is not a seeker may cause request signature errors, or request bodies
// not being sent for GET, HEAD and DELETE HTTP methods.
//
// Deprecated: Should only be used with io.ReadSeeker. If using for
// S3 PutObject to stream content use s3manager.Uploader instead.
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
return ReaderSeekerCloser{r}
}
@@ -44,6 +50,12 @@ func (r ReaderSeekerCloser) Seek(offset int64, whence int) (int64, error) {
return int64(0), nil
}
// IsSeeker returns whether the underlying reader is also a seeker.
func (r ReaderSeekerCloser) IsSeeker() bool {
_, ok := r.r.(io.Seeker)
return ok
}
// Close closes the ReaderSeekerCloser.
//
// If the ReaderSeekerCloser is not an io.Closer nothing will be done.

View File

@@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK
const SDKVersion = "1.5.12"
const SDKVersion = "1.6.2"

View File

@@ -1,70 +0,0 @@
// Package endpoints validates regional endpoints for services.
package endpoints
//go:generate go run -tags codegen ../model/cli/gen-endpoints/main.go endpoints.json endpoints_map.go
//go:generate gofmt -s -w endpoints_map.go
import (
"fmt"
"regexp"
"strings"
)
// NormalizeEndpoint takes and endpoint and service API information to return a
// normalized endpoint and signing region. If the endpoint is not an empty string
// the service name and region will be used to look up the service's API endpoint.
// If the endpoint is provided the scheme will be added if it is not present.
func NormalizeEndpoint(endpoint, serviceName, region string, disableSSL, useDualStack bool) (normEndpoint, signingRegion string) {
if endpoint == "" {
return EndpointForRegion(serviceName, region, disableSSL, useDualStack)
}
return AddScheme(endpoint, disableSSL), ""
}
// EndpointForRegion returns an endpoint and its signing region for a service and region.
// if the service and region pair are not found endpoint and signingRegion will be empty.
func EndpointForRegion(svcName, region string, disableSSL, useDualStack bool) (endpoint, signingRegion string) {
dualStackField := ""
if useDualStack {
dualStackField = "/dualstack"
}
derivedKeys := []string{
region + "/" + svcName + dualStackField,
region + "/*" + dualStackField,
"*/" + svcName + dualStackField,
"*/*" + dualStackField,
}
for _, key := range derivedKeys {
if val, ok := endpointsMap.Endpoints[key]; ok {
ep := val.Endpoint
ep = strings.Replace(ep, "{region}", region, -1)
ep = strings.Replace(ep, "{service}", svcName, -1)
endpoint = ep
signingRegion = val.SigningRegion
break
}
}
return AddScheme(endpoint, disableSSL), signingRegion
}
// Regular expression to determine if the endpoint string is prefixed with a scheme.
var schemeRE = regexp.MustCompile("^([^:]+)://")
// AddScheme adds the HTTP or HTTPS schemes to a endpoint URL if there is no
// scheme. If disableSSL is true HTTP will be added instead of the default HTTPS.
func AddScheme(endpoint string, disableSSL bool) string {
if endpoint != "" && !schemeRE.MatchString(endpoint) {
scheme := "https"
if disableSSL {
scheme = "http"
}
endpoint = fmt.Sprintf("%s://%s", scheme, endpoint)
}
return endpoint
}

View File

@@ -1,82 +0,0 @@
{
"version": 2,
"endpoints": {
"*/*": {
"endpoint": "{service}.{region}.amazonaws.com"
},
"cn-north-1/*": {
"endpoint": "{service}.{region}.amazonaws.com.cn",
"signatureVersion": "v4"
},
"cn-north-1/ec2metadata": {
"endpoint": "http://169.254.169.254/latest"
},
"us-gov-west-1/iam": {
"endpoint": "iam.us-gov.amazonaws.com"
},
"us-gov-west-1/sts": {
"endpoint": "sts.us-gov-west-1.amazonaws.com"
},
"us-gov-west-1/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"us-gov-west-1/ec2metadata": {
"endpoint": "http://169.254.169.254/latest"
},
"*/budgets": {
"endpoint": "budgets.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/cloudfront": {
"endpoint": "cloudfront.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/cloudsearchdomain": {
"endpoint": "",
"signingRegion": "us-east-1"
},
"*/data.iot": {
"endpoint": "",
"signingRegion": "us-east-1"
},
"*/ec2metadata": {
"endpoint": "http://169.254.169.254/latest"
},
"*/iam": {
"endpoint": "iam.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/importexport": {
"endpoint": "importexport.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/route53": {
"endpoint": "route53.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/sts": {
"endpoint": "sts.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/waf": {
"endpoint": "waf.amazonaws.com",
"signingRegion": "us-east-1"
},
"us-east-1/sdb": {
"endpoint": "sdb.amazonaws.com",
"signingRegion": "us-east-1"
},
"*/s3": {
"endpoint": "s3-{region}.amazonaws.com"
},
"*/s3/dualstack": {
"endpoint": "s3.dualstack.{region}.amazonaws.com"
},
"us-east-1/s3": {
"endpoint": "s3.amazonaws.com"
},
"eu-central-1/s3": {
"endpoint": "{service}.{region}.amazonaws.com"
}
}
}

View File

@@ -1,95 +0,0 @@
package endpoints
// THIS FILE IS AUTOMATICALLY GENERATED. DO NOT EDIT.
type endpointStruct struct {
Version int
Endpoints map[string]endpointEntry
}
type endpointEntry struct {
Endpoint string
SigningRegion string
}
var endpointsMap = endpointStruct{
Version: 2,
Endpoints: map[string]endpointEntry{
"*/*": {
Endpoint: "{service}.{region}.amazonaws.com",
},
"*/budgets": {
Endpoint: "budgets.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/cloudfront": {
Endpoint: "cloudfront.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/cloudsearchdomain": {
Endpoint: "",
SigningRegion: "us-east-1",
},
"*/data.iot": {
Endpoint: "",
SigningRegion: "us-east-1",
},
"*/ec2metadata": {
Endpoint: "http://169.254.169.254/latest",
},
"*/iam": {
Endpoint: "iam.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/importexport": {
Endpoint: "importexport.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/route53": {
Endpoint: "route53.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"*/s3/dualstack": {
Endpoint: "s3.dualstack.{region}.amazonaws.com",
},
"*/sts": {
Endpoint: "sts.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/waf": {
Endpoint: "waf.amazonaws.com",
SigningRegion: "us-east-1",
},
"cn-north-1/*": {
Endpoint: "{service}.{region}.amazonaws.com.cn",
},
"cn-north-1/ec2metadata": {
Endpoint: "http://169.254.169.254/latest",
},
"eu-central-1/s3": {
Endpoint: "{service}.{region}.amazonaws.com",
},
"us-east-1/s3": {
Endpoint: "s3.amazonaws.com",
},
"us-east-1/sdb": {
Endpoint: "sdb.amazonaws.com",
SigningRegion: "us-east-1",
},
"us-gov-west-1/ec2metadata": {
Endpoint: "http://169.254.169.254/latest",
},
"us-gov-west-1/iam": {
Endpoint: "iam.us-gov.amazonaws.com",
},
"us-gov-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-gov-west-1/sts": {
Endpoint: "sts.us-gov-west-1.amazonaws.com",
},
},
}

View File

@@ -65,6 +65,11 @@ func BuildAsGET(r *request.Request) {
func buildLocationElements(r *request.Request, v reflect.Value, buildGETQuery bool) {
query := r.HTTPRequest.URL.Query()
// Set up the raw path to match the base path pattern. This is needed
// so that when the path is mutated, a custom escaped version can be
// stored in RawPath that will be used by the Go client.
r.HTTPRequest.URL.RawPath = r.HTTPRequest.URL.Path
for i := 0; i < v.NumField(); i++ {
m := v.Field(i)
if n := v.Type().Field(i).Name; n[0:1] == strings.ToLower(n[0:1]) {
@@ -107,7 +112,9 @@ func buildLocationElements(r *request.Request, v reflect.Value, buildGETQuery bo
}
r.HTTPRequest.URL.RawQuery = query.Encode()
updatePath(r.HTTPRequest.URL, r.HTTPRequest.URL.Path, aws.BoolValue(r.Config.DisableRestProtocolURICleaning))
if !aws.BoolValue(r.Config.DisableRestProtocolURICleaning) {
cleanPath(r.HTTPRequest.URL)
}
}
func buildBody(r *request.Request, v reflect.Value) {
@@ -171,10 +178,11 @@ func buildURI(u *url.URL, v reflect.Value, name string) error {
return awserr.New("SerializationError", "failed to encode REST request", err)
}
uri := u.Path
uri = strings.Replace(uri, "{"+name+"}", EscapePath(value, true), -1)
uri = strings.Replace(uri, "{"+name+"+}", EscapePath(value, false), -1)
u.Path = uri
u.Path = strings.Replace(u.Path, "{"+name+"}", value, -1)
u.Path = strings.Replace(u.Path, "{"+name+"+}", value, -1)
u.RawPath = strings.Replace(u.RawPath, "{"+name+"}", EscapePath(value, true), -1)
u.RawPath = strings.Replace(u.RawPath, "{"+name+"+}", EscapePath(value, false), -1)
return nil
}
@@ -208,27 +216,17 @@ func buildQueryString(query url.Values, v reflect.Value, name string) error {
return nil
}
func updatePath(url *url.URL, urlPath string, disableRestProtocolURICleaning bool) {
scheme, query := url.Scheme, url.RawQuery
func cleanPath(u *url.URL) {
hasSlash := strings.HasSuffix(u.Path, "/")
hasSlash := strings.HasSuffix(urlPath, "/")
// clean up path, removing duplicate `/`
u.Path = path.Clean(u.Path)
u.RawPath = path.Clean(u.RawPath)
// clean up path
if !disableRestProtocolURICleaning {
urlPath = path.Clean(urlPath)
if hasSlash && !strings.HasSuffix(u.Path, "/") {
u.Path += "/"
u.RawPath += "/"
}
if hasSlash && !strings.HasSuffix(urlPath, "/") {
urlPath += "/"
}
// get formatted URL minus scheme so we can build this into Opaque
url.Scheme, url.Path, url.RawQuery = "", "", ""
s := url.String()
url.Scheme = scheme
url.RawQuery = query
// build opaque URI
url.Opaque = s + urlPath
}
// EscapePath escapes part of a URL path in Amazon style
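A minimal standalone sketch of the trailing-slash-preserving clean used by the new cleanPath, assuming only the standard library.

```go
package main

import (
	"fmt"
	"path"
	"strings"
)

// cleanKeepTrailingSlash collapses duplicate "/" and dot segments while
// preserving a trailing slash, mirroring the hasSlash bookkeeping above.
func cleanKeepTrailingSlash(p string) string {
	hasSlash := strings.HasSuffix(p, "/")
	p = path.Clean(p)
	if hasSlash && !strings.HasSuffix(p, "/") {
		p += "/"
	}
	return p
}

func main() {
	fmt.Println(cleanKeepTrailingSlash("/bucket//key/")) // /bucket/key/
}
```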

View File

@@ -1,6 +1,7 @@
package rest
import (
"bytes"
"encoding/base64"
"fmt"
"io"
@@ -11,7 +12,6 @@ import (
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
@@ -70,10 +70,16 @@ func unmarshalBody(r *request.Request, v reflect.Value) {
}
default:
switch payload.Type().String() {
case "io.ReadSeeker":
payload.Set(reflect.ValueOf(aws.ReadSeekCloser(r.HTTPResponse.Body)))
case "aws.ReadSeekCloser", "io.ReadCloser":
case "io.ReadCloser":
payload.Set(reflect.ValueOf(r.HTTPResponse.Body))
case "io.ReadSeeker":
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = awserr.New("SerializationError",
"failed to read response body", err)
return
}
payload.Set(reflect.ValueOf(ioutil.NopCloser(bytes.NewReader(b))))
default:
io.Copy(ioutil.Discard, r.HTTPResponse.Body)
defer r.HTTPResponse.Body.Close()

View File

@@ -111,11 +111,8 @@ func parseStruct(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
elems := node.Children[name]
if elems == nil { // try to find the field in attributes
for _, a := range node.Attr {
if name == strings.Join([]string{a.Name.Space, a.Name.Local}, ":") {
// turn this into a text node for de-serializing
elems = []*XMLNode{{Text: a.Value}}
}
if val, ok := node.findElem(name); ok {
elems = []*XMLNode{{Text: val}}
}
}

View File

@@ -2,6 +2,7 @@ package xmlutil
import (
"encoding/xml"
"fmt"
"io"
"sort"
)
@@ -12,6 +13,9 @@ type XMLNode struct {
Children map[string][]*XMLNode `json:",omitempty"`
Text string `json:",omitempty"`
Attr []xml.Attr `json:",omitempty"`
namespaces map[string]string
parent *XMLNode
}
// NewXMLElement returns a pointer to a new XMLNode initialized to default values.
@@ -59,41 +63,52 @@ func XMLToStruct(d *xml.Decoder, s *xml.StartElement) (*XMLNode, error) {
slice = []*XMLNode{}
}
node, e := XMLToStruct(d, &el)
out.findNamespaces()
if e != nil {
return out, e
}
node.Name = typed.Name
node.Attr = out.Attr
node = adaptNode(node)
node.findNamespaces()
tempOut := *out
// Save into a temp variable, simply because out gets squashed during
// loop iterations
node.parent = &tempOut
slice = append(slice, node)
out.Children[name] = slice
case xml.EndElement:
if s != nil && s.Name.Local == typed.Name.Local { // matching end token
return out, nil
}
out = &XMLNode{}
}
}
return out, nil
}
func adaptNode(node *XMLNode) *XMLNode {
func (n *XMLNode) findNamespaces() {
ns := map[string]string{}
for _, a := range node.Attr {
for _, a := range n.Attr {
if a.Name.Space == "xmlns" {
ns[a.Value] = a.Name.Local
break
}
}
for i, a := range node.Attr {
if a.Name.Space == "xmlns" {
continue
}
if v, ok := ns[node.Attr[i].Name.Space]; ok {
node.Attr[i].Name.Space = v
n.namespaces = ns
}
func (n *XMLNode) findElem(name string) (string, bool) {
for node := n; node != nil; node = node.parent {
for _, a := range node.Attr {
namespace := a.Name.Space
if v, ok := node.namespaces[namespace]; ok {
namespace = v
}
if name == fmt.Sprintf("%s:%s", namespace, a.Name.Local) {
return a.Value, true
}
}
}
return node
return "", false
}
// StructToXML writes an XMLNode to a xml.Encoder as tokens.

View File

@@ -151,16 +151,17 @@ const ServiceName = "dynamodb"
// svc := dynamodb.New(mySession, aws.NewConfig().WithRegion("us-west-2"))
func New(p client.ConfigProvider, cfgs ...*aws.Config) *DynamoDB {
c := p.ClientConfig(ServiceName, cfgs...)
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion)
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion, c.SigningName)
}
// newClient creates, initializes and returns a new service client instance.
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion string) *DynamoDB {
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion, signingName string) *DynamoDB {
svc := &DynamoDB{
Client: client.New(
cfg,
metadata.ClientInfo{
ServiceName: ServiceName,
SigningName: signingName,
SigningRegion: signingRegion,
Endpoint: endpoint,
APIVersion: "2012-08-10",

File diff suppressed because it is too large

View File

@@ -5,8 +5,8 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/private/endpoints"
)
func init() {
@@ -39,12 +39,20 @@ func fillPresignedURL(r *request.Request) {
WithRegion(aws.StringValue(origParams.SourceRegion)))
clientInfo := r.ClientInfo
clientInfo.Endpoint, clientInfo.SigningRegion = endpoints.EndpointForRegion(
clientInfo.ServiceName,
aws.StringValue(cfg.Region),
aws.BoolValue(cfg.DisableSSL),
aws.BoolValue(cfg.UseDualStack),
resolved, err := r.Config.EndpointResolver.EndpointFor(
clientInfo.ServiceName, aws.StringValue(cfg.Region),
func(opt *endpoints.Options) {
opt.DisableSSL = aws.BoolValue(cfg.DisableSSL)
opt.UseDualStack = aws.BoolValue(cfg.UseDualStack)
},
)
if err != nil {
r.Error = err
return
}
clientInfo.Endpoint = resolved.URL
clientInfo.SigningRegion = resolved.SigningRegion
// Presign a CopySnapshot request with modified params
req := request.New(*cfg, clientInfo, r.Handlers, r.Retryer, r.Operation, newParams, r.Data)

View File

@@ -42,19 +42,20 @@ const ServiceName = "ec2"
// svc := ec2.New(mySession, aws.NewConfig().WithRegion("us-west-2"))
func New(p client.ConfigProvider, cfgs ...*aws.Config) *EC2 {
c := p.ClientConfig(ServiceName, cfgs...)
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion)
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion, c.SigningName)
}
// newClient creates, initializes and returns a new service client instance.
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion string) *EC2 {
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion, signingName string) *EC2 {
svc := &EC2{
Client: client.New(
cfg,
metadata.ClientInfo{
ServiceName: ServiceName,
SigningName: signingName,
SigningRegion: signingRegion,
Endpoint: endpoint,
APIVersion: "2016-09-15",
APIVersion: "2016-11-15",
},
handlers,
),

View File

@@ -91,16 +91,17 @@ const ServiceName = "iam"
// svc := iam.New(mySession, aws.NewConfig().WithRegion("us-west-2"))
func New(p client.ConfigProvider, cfgs ...*aws.Config) *IAM {
c := p.ClientConfig(ServiceName, cfgs...)
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion)
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion, c.SigningName)
}
// newClient creates, initializes and returns a new service client instance.
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion string) *IAM {
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion, signingName string) *IAM {
svc := &IAM{
Client: client.New(
cfg,
metadata.ClientInfo{
ServiceName: ServiceName,
SigningName: signingName,
SigningRegion: signingRegion,
Endpoint: endpoint,
APIVersion: "2010-05-08",

View File

@@ -9740,8 +9740,7 @@ type GetObjectTaggingInput struct {
// Key is a required field
Key *string `location:"uri" locationName:"Key" min:"1" type:"string" required:"true"`
// VersionId is a required field
VersionId *string `location:"uri" locationName:"VersionId" type:"string" required:"true"`
VersionId *string `location:"querystring" locationName:"versionId" type:"string"`
}
// String returns the string representation
@@ -9766,9 +9765,6 @@ func (s *GetObjectTaggingInput) Validate() error {
if s.Key != nil && len(*s.Key) < 1 {
invalidParams.Add(request.NewErrParamMinLen("Key", 1))
}
if s.VersionId == nil {
invalidParams.Add(request.NewErrParamRequired("VersionId"))
}
if invalidParams.Len() > 0 {
return invalidParams
@@ -15453,8 +15449,7 @@ type PutObjectTaggingInput struct {
// Tagging is a required field
Tagging *Tagging `locationName:"Tagging" type:"structure" required:"true"`
// VersionId is a required field
VersionId *string `location:"uri" locationName:"VersionId" type:"string" required:"true"`
VersionId *string `location:"querystring" locationName:"versionId" type:"string"`
}
// String returns the string representation
@@ -15482,9 +15477,6 @@ func (s *PutObjectTaggingInput) Validate() error {
if s.Tagging == nil {
invalidParams.Add(request.NewErrParamRequired("Tagging"))
}
if s.VersionId == nil {
invalidParams.Add(request.NewErrParamRequired("VersionId"))
}
if s.Tagging != nil {
if err := s.Tagging.Validate(); err != nil {
invalidParams.AddNested("Tagging", err.(request.ErrInvalidParams))

View File

@@ -1,7 +1,6 @@
package s3
import (
"bytes"
"fmt"
"net/url"
"regexp"
@@ -88,24 +87,26 @@ func updateEndpointForAccelerate(r *request.Request) {
return
}
// Change endpoint from s3(-[a-z0-9-])?.amazonaws.com to s3-accelerate.amazonaws.com
r.HTTPRequest.URL.Host = replaceHostRegion(r.HTTPRequest.URL.Host, "accelerate")
parts := strings.Split(r.HTTPRequest.URL.Host, ".")
if len(parts) < 3 {
r.Error = awserr.New("InvalidParameterExecption",
fmt.Sprintf("unable to update endpoint host for S3 accelerate, hostname invalid, %s",
r.HTTPRequest.URL.Host), nil)
return
}
if aws.BoolValue(r.Config.UseDualStack) {
host := []byte(r.HTTPRequest.URL.Host)
// Strip region from hostname
if idx := bytes.Index(host, accelElem); idx >= 0 {
start := idx + len(accelElem)
if end := bytes.IndexByte(host[start:], '.'); end >= 0 {
end += start + 1
copy(host[start:], host[end:])
host = host[:len(host)-(end-start)]
r.HTTPRequest.URL.Host = string(host)
}
if parts[0] == "s3" || strings.HasPrefix(parts[0], "s3-") {
parts[0] = "s3-accelerate"
}
for i := 1; i+1 < len(parts); i++ {
if parts[i] == aws.StringValue(r.Config.Region) {
parts = append(parts[:i], parts[i+1:]...)
break
}
}
r.HTTPRequest.URL.Host = strings.Join(parts, ".")
moveBucketToHost(r.HTTPRequest.URL, bucket)
}
@@ -159,28 +160,3 @@ func moveBucketToHost(u *url.URL, bucket string) {
u.Path = "/"
}
}
const s3HostPrefix = "s3"
// replaceHostRegion replaces the S3 region string in the host with the
// value provided. If v is empty the host prefix returned will be s3.
func replaceHostRegion(host, v string) string {
if !strings.HasPrefix(host, s3HostPrefix) {
return host
}
suffix := host[len(s3HostPrefix):]
for i := len(s3HostPrefix); i < len(host); i++ {
if host[i] == '.' {
// Trim until '.' leave the it in place.
suffix = host[i:]
break
}
}
if len(v) == 0 {
return fmt.Sprintf("s3%s", suffix)
}
return fmt.Sprintf("s3-%s%s", v, suffix)
}

View File

@@ -39,16 +39,17 @@ const ServiceName = "s3"
// svc := s3.New(mySession, aws.NewConfig().WithRegion("us-west-2"))
func New(p client.ConfigProvider, cfgs ...*aws.Config) *S3 {
c := p.ClientConfig(ServiceName, cfgs...)
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion)
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion, c.SigningName)
}
// newClient creates, initializes and returns a new service client instance.
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion string) *S3 {
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion, signingName string) *S3 {
svc := &S3{
Client: client.New(
cfg,
metadata.ClientInfo{
ServiceName: ServiceName,
SigningName: signingName,
SigningRegion: signingRegion,
Endpoint: endpoint,
APIVersion: "2006-03-01",

View File

@@ -5,7 +5,6 @@ import (
"io/ioutil"
"net/http"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
@@ -17,8 +16,8 @@ func copyMultipartStatusOKUnmarhsalError(r *request.Request) {
return
}
body := bytes.NewReader(b)
r.HTTPResponse.Body = aws.ReadSeekCloser(body)
defer r.HTTPResponse.Body.(aws.ReaderSeekerCloser).Seek(0, 0)
r.HTTPResponse.Body = ioutil.NopCloser(body)
defer body.Seek(0, 0)
if body.Len() == 0 {
// If there is no body don't attempt to parse the body.

View File

@@ -447,7 +447,7 @@ func (c *STS) AssumeRoleWithWebIdentityRequest(input *AssumeRoleWithWebIdentityI
// For more information about how to use web identity federation and the AssumeRoleWithWebIdentity
// API, see the following resources:
//
// * Using Web Identity Federation APIs for Mobile Apps (http://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_providers_oidc_manual)
// * Using Web Identity Federation APIs for Mobile Apps (http://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_providers_oidc_manual.html)
// and Federation Through a Web-based Identity Provider (http://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_request.html#api_assumerolewithwebidentity).
//
//
@@ -761,7 +761,7 @@ func (c *STS) GetFederationTokenRequest(input *GetFederationTokenInput) (req *re
//
// * You cannot use these credentials to call any IAM APIs.
//
// * You cannot call any STS APIs.
// * You cannot call any STS APIs except GetCallerIdentity.
//
// Permissions
//
@@ -904,7 +904,7 @@ func (c *STS) GetSessionTokenRequest(input *GetSessionTokenInput) (req *request.
// * You cannot call any IAM APIs unless MFA authentication information is
// included in the request.
//
// * You cannot call any STS API exceptAssumeRole.
// * You cannot call any STS API exceptAssumeRole or GetCallerIdentity.
//
// We recommend that you do not call GetSessionToken with root account credentials.
// Instead, follow our best practices (http://docs.aws.amazon.com/IAM/latest/UserGuide/best-practices.html#create-iam-users)

View File

@@ -83,16 +83,17 @@ const ServiceName = "sts"
// svc := sts.New(mySession, aws.NewConfig().WithRegion("us-west-2"))
func New(p client.ConfigProvider, cfgs ...*aws.Config) *STS {
c := p.ClientConfig(ServiceName, cfgs...)
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion)
return newClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion, c.SigningName)
}
// newClient creates, initializes and returns a new service client instance.
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion string) *STS {
func newClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegion, signingName string) *STS {
svc := &STS{
Client: client.New(
cfg,
metadata.ClientInfo{
ServiceName: ServiceName,
SigningName: signingName,
SigningRegion: signingRegion,
Endpoint: endpoint,
APIVersion: "2011-06-15",

View File

@@ -0,0 +1,107 @@
## Circonus gometrics options
### Example defaults
```go
package main
import (
"fmt"
"io/ioutil"
"log"
"os"
"path"
cgm "github.com/circonus-labs/circonus-gometrics"
)
func main() {
cfg := &cgm.Config{}
// Defaults
// General
cfg.Debug = false
cfg.Log = log.New(ioutil.Discard, "", log.LstdFlags)
cfg.Interval = "10s"
cfg.ResetCounters = "true"
cfg.ResetGauges = "true"
cfg.ResetHistograms = "true"
cfg.ResetText = "true"
// API
cfg.CheckManager.API.TokenKey = ""
cfg.CheckManager.API.TokenApp = "circonus-gometrics"
cfg.CheckManager.API.URL = "https://api.circonus.com/v2"
// Check
_, an := path.Split(os.Args[0])
hn, _ := os.Hostname()
cfg.CheckManager.Check.ID = ""
cfg.CheckManager.Check.SubmissionURL = ""
cfg.CheckManager.Check.InstanceID = fmt.Sprintf("%s:%s", hn, an)
cfg.CheckManager.Check.TargetHost = cfg.CheckManager.Check.InstanceID
cfg.CheckManager.Check.DisplayName = cfg.CheckManager.Check.InstanceID
cfg.CheckManager.Check.SearchTag = fmt.Sprintf("service:%s", an)
cfg.CheckManager.Check.Tags = ""
cfg.CheckManager.Check.Secret = "" // randomly generated sha256 hash
cfg.CheckManager.Check.MaxURLAge = "5m"
cfg.CheckManager.Check.ForceMetricActivation = "false"
// Broker
cfg.CheckManager.Broker.ID = ""
cfg.CheckManager.Broker.SelectTag = ""
cfg.CheckManager.Broker.MaxResponseTime = "500ms"
// create a new cgm instance and start sending metrics...
// see the complete example in the main README.
}
```
## Options
| Option | Default | Description |
| ------ | ------- | ----------- |
| General ||
| `cfg.Log` | none | log.Logger instance to send logging messages. Default is to discard messages. If Debug is turned on and no instance is specified, messages will go to stderr. |
| `cfg.Debug` | false | Turn on debugging messages. |
| `cfg.Interval` | "10s" | Interval at which metrics are flushed and sent to Circonus. |
| `cfg.ResetCounters` | "true" | Reset counter metrics after each submission. Change to "false" to retain (and continue submitting) the last value.|
| `cfg.ResetGauges` | "true" | Reset gauge metrics after each submission. Change to "false" to retain (and continue submitting) the last value.|
| `cfg.ResetHistograms` | "true" | Reset histogram metrics after each submission. Change to "false" to retain (and continue submitting) the last value.|
| `cfg.ResetText` | "true" | Reset text metrics after each submission. Change to "false" to retain (and continue submitting) the last value.|
|API||
| `cfg.CheckManager.API.TokenKey` | "" | [Circonus API Token key](https://login.circonus.com/user/tokens) |
| `cfg.CheckManager.API.TokenApp` | "circonus-gometrics" | App associated with API token |
| `cfg.CheckManager.API.URL` | "https://api.circonus.com/v2" | Circonus API URL |
|Check||
| `cfg.CheckManager.Check.ID` | "" | Check ID of previously created check. (*Note: **check id** not **check bundle id**.*) |
| `cfg.CheckManager.Check.SubmissionURL` | "" | Submission URL of previously created check. |
| `cfg.CheckManager.Check.InstanceID` | hostname:program name | An identifier for the 'group of metrics emitted by this process or service'. |
| `cfg.CheckManager.Check.TargetHost` | InstanceID | Explicit setting of `check.target`. |
| `cfg.CheckManager.Check.DisplayName` | InstanceID | Custom `check.display_name`. Shows in UI check list. |
| `cfg.CheckManager.Check.SearchTag` | service:program name | Specific tag used to search for an existing check when neither SubmissionURL nor ID are provided. |
| `cfg.CheckManager.Check.Tags` | "" | List (comma separated) of tags to add to check when it is being created. The SearchTag will be added to the list. |
| `cfg.CheckManager.Check.Secret` | randomly generated | A secret to use when creating an httptrap check. |
| `cfg.CheckManager.Check.MaxURLAge` | "5m" | Maximum amount of time to retry a [failing] submission URL before refreshing it. |
| `cfg.CheckManager.Check.ForceMetricActivation` | "false" | If a metric has been disabled via the UI the default behavior is to *not* re-activate the metric; this setting overrides the behavior and will re-activate the metric when it is encountered. |
|Broker||
| `cfg.CheckManager.Broker.ID` | "" | ID of a specific broker to use when creating a check. Default is to use a random enterprise broker or the public Circonus default broker. |
| `cfg.CheckManager.Broker.SelectTag` | "" | Used to select a broker with the same tag(s). If more than one broker has the tag(s), one will be selected randomly from the resulting list. (e.g. could be used to select one from a list of brokers serving a specific colo/region. "dc:sfo", "loc:nyc,dc:nyc01", "zone:us-west") |
| `cfg.CheckManager.Broker.MaxResponseTime` | "500ms" | Maximum amount of time to wait for a broker connection test to be considered valid. (If latency is greater, the broker will be considered invalid and not available for selection.) |
## Notes:
* All options are *strings* with the following exceptions:
* `cfg.Log` - an instance of [`log.Logger`](https://golang.org/pkg/log/#Logger) or something else (e.g. [logrus](https://github.com/Sirupsen/logrus)) which can be used to satisfy the interface requirements.
* `cfg.Debug` - a boolean true|false.
* At a minimum, one of either `API.TokenKey` or `Check.SubmissionURL` is **required** for cgm to function.
* Check management can be disabled by providing a `Check.SubmissionURL` without an `API.TokenKey`. Note: the supplied URL needs to be http or the broker needs to be running with a cert which can be verified. Otherwise, the `API.TokenKey` will be required to retrieve the correct CA certificate to validate the broker's cert for the SSL connection.
* A note on `Check.InstanceID`: the instance id is used to consistently identify a check. The display name can be changed in the UI, and the hostname may be ephemeral, so for metric continuity the instance id is used to locate existing checks. Since check.target is never actually used by an httptrap check, it is more decorative than functional, and a valid FQDN is not required for an httptrap check.target. However, using the instance id as the target can pollute the Host list in the UI with host:application specific entries.
* Check identification precedence
1. Check SubmissionURL
2. Check ID
3. Search
1. Search for an active httptrap check for TargetHost which has the SearchTag
2. Search for an active httptrap check which has the SearchTag and the InstanceID in the notes field
3. Create a new check
* Broker selection
1. If Broker.ID or Broker.SelectTag are not specified, a broker will be selected randomly from the list of brokers available to the API token. Enterprise brokers take precedence. A viable broker is "active", has the "httptrap" module enabled, and responds within Broker.MaxResponseTime.

View File

@@ -1,14 +1,18 @@
# Circonus metrics tracking for Go applications
This library supports named counters, gauges and histograms.
It also provides convenience wrappers for registering latency
instrumented functions with Go's builtin http server.
This library supports named counters, gauges and histograms. It also provides convenience wrappers for registering latency instrumented functions with Go's builtin http server.
Initializing only requires setting an ApiToken.
Initializing only requires setting an [API Token](https://login.circonus.com/user/tokens) at a minimum.
## Options
See [OPTIONS.md](OPTIONS.md) for information on all of the available cgm options.
## Example
**rough and simple**
### Bare bones minimum
A working cut-n-paste example. Simply set the required environment variable `CIRCONUS_API_TOKEN` and run.
```go
package main
@@ -17,6 +21,83 @@ import (
"log"
"math/rand"
"os"
"os/signal"
"syscall"
"time"
cgm "github.com/circonus-labs/circonus-gometrics"
)
func main() {
logger := log.New(os.Stdout, "", log.LstdFlags)
logger.Println("Configuring cgm")
cmc := &cgm.Config{}
cmc.Debug = false // set to true for debug messages
cmc.Log = logger
// Circonus API Token key (https://login.circonus.com/user/tokens)
cmc.CheckManager.API.TokenKey = os.Getenv("CIRCONUS_API_TOKEN")
logger.Println("Creating new cgm instance")
metrics, err := cgm.NewCirconusMetrics(cmc)
if err != nil {
logger.Println(err)
os.Exit(1)
}
src := rand.NewSource(time.Now().UnixNano())
rnd := rand.New(src)
logger.Println("Starting cgm internal auto-flush timer")
metrics.Start()
logger.Println("Adding ctrl-c trap")
c := make(chan os.Signal, 2)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
logger.Println("Received CTRL-C, flushing outstanding metrics before exit")
metrics.Flush()
os.Exit(0)
}()
logger.Println("Starting to send metrics")
// number of "sets" of metrics to send
max := 60
for i := 1; i < max; i++ {
logger.Printf("\tmetric set %d of %d", i, 60)
metrics.Timing("foo", rnd.Float64()*10)
metrics.Increment("bar")
metrics.Gauge("baz", 10)
time.Sleep(time.Second)
}
metrics.SetText("fini", "complete")
logger.Println("Flushing any outstanding metrics manually")
metrics.Flush()
}
```
### A more complete example
A working cut-n-paste example with all options available for modification. It also demonstrates metric tagging.
```go
package main
import (
"log"
"math/rand"
"os"
"os/signal"
"syscall"
"time"
cgm "github.com/circonus-labs/circonus-gometrics"
@@ -30,127 +111,47 @@ func main() {
cmc := &cgm.Config{}
// Interval at which metrics are submitted to Circonus, default: 10 seconds
// cmc.Interval = "10s" // 10 seconds
// General
// Enable debug messages, default: false
cmc.Debug = true
// Send debug messages to specific log.Logger instance
// default: if debug stderr, else, discard
cmc.Log = logger
// Reset counter metrics after each submission, default: "true"
// Change to "false" to retain (and continue submitting) the last value.
// cmc.ResetCounters = "true"
// Reset gauge metrics after each submission, default: "true"
// Change to "false" to retain (and continue submitting) the last value.
// cmc.ResetGauges = "true"
// Reset histogram metrics after each submission, default: "true"
// Change to "false" to retain (and continue submitting) the last value.
// cmc.ResetHistograms = "true"
// Reset text metrics after each submission, default: "true"
// Change to "false" to retain (and continue submitting) the last value.
// cmc.ResetText = "true"
cmc.Interval = "10s"
cmc.Log = logger
cmc.Debug = false
cmc.ResetCounters = "true"
cmc.ResetGauges = "true"
cmc.ResetHistograms = "true"
cmc.ResetText = "true"
// Circonus API configuration options
//
// Token, no default (blank disables check manager)
cmc.CheckManager.API.TokenKey = os.Getenv("CIRCONUS_API_TOKEN")
// App name, default: circonus-gometrics
cmc.CheckManager.API.TokenApp = os.Getenv("CIRCONUS_API_APP")
// URL, default: https://api.circonus.com/v2
cmc.CheckManager.API.URL = os.Getenv("CIRCONUS_API_URL")
// Check configuration options
//
// precedence 1 - explicit submission_url
// precedence 2 - specific check id (note: not a check bundle id)
// precedence 3 - search using instanceId and searchTag
// otherwise: if an applicable check is NOT specified or found, an
// attempt will be made to automatically create one
//
// Submission URL for an existing [httptrap] check
cmc.CheckManager.Check.SubmissionURL = os.Getenv("CIRCONUS_SUBMISION_URL")
// ID of an existing [httptrap] check (note: check id not check bundle id)
cmc.CheckManager.Check.ID = os.Getenv("CIRCONUS_CHECK_ID")
// if neither a submission url nor check id are provided, an attempt will be made to find an existing
// httptrap check by using the circonus api to search for a check matching the following criteria:
// an active check,
// of type httptrap,
// where the target/host is equal to InstanceId - see below
// and the check has a tag equal to SearchTag - see below
// Instance ID - an identifier for the 'group of metrics emitted by this process or service'
// this is used as the value for check.target (aka host)
// default: 'hostname':'program name'
// note: for a persistent instance that is ephemeral or transient where metric continuity is
// desired set this explicitly so that the current hostname will not be used.
// cmc.CheckManager.Check.InstanceID = ""
// Search tag - specific tag(s) used in conjunction with isntanceId to search for an
// existing check. comma separated string of tags (spaces will be removed, no commas
// in tag elements).
// default: service:application name (e.g. service:consul service:nomad etc.)
// cmc.CheckManager.Check.SearchTag = ""
// Check secret, default: generated when a check needs to be created
// cmc.CheckManager.Check.Secret = ""
// Additional tag(s) to add when *creating* a check. comma separated string
// of tags (spaces will be removed, no commas in tag elements).
// (e.g. group:abc or service_role:agent,group:xyz).
// default: none
// cmc.CheckManager.Check.Tags = ""
// max amount of time to to hold on to a submission url
// when a given submission fails (due to retries) if the
// time the url was last updated is > than this, the trap
// url will be refreshed (e.g. if the broker is changed
// in the UI) default 5 minutes
// cmc.CheckManager.Check.MaxURLAge = "5m"
// custom display name for check, default: "InstanceId /cgm"
// cmc.CheckManager.Check.DisplayName = ""
// force metric activation - if a metric has been disabled via the UI
// the default behavior is to *not* re-activate the metric; this setting
// overrides the behavior and will re-activate the metric when it is
// encountered. "(true|false)", default "false"
// cmc.CheckManager.Check.ForceMetricActivation = "false"
cmc.CheckManager.Check.InstanceID = ""
cmc.CheckManager.Check.DisplayName = ""
cmc.CheckManager.Check.TargetHost = ""
// if hn, err := os.Hostname(); err == nil {
// cmc.CheckManager.Check.TargetHost = hn
// }
cmc.CheckManager.Check.SearchTag = ""
cmc.CheckManager.Check.Secret = ""
cmc.CheckManager.Check.Tags = ""
cmc.CheckManager.Check.MaxURLAge = "5m"
cmc.CheckManager.Check.ForceMetricActivation = "false"
// Broker configuration options
//
// Broker ID of specific broker to use, default: random enterprise broker or
// Circonus default if no enterprise brokers are available.
// default: only used if set
// cmc.CheckManager.Broker.ID = ""
// used to select a broker with the same tag(s) (e.g. can be used to dictate that a broker
// serving a specific location should be used. "dc:sfo", "loc:nyc,dc:nyc01", "zone:us-west")
// if more than one broker has the tag(s), one will be selected randomly from the resulting
// list. comma separated string of tags (spaces will be removed, no commas in tag elements).
// default: none
// cmc.CheckManager.Broker.SelectTag = ""
// longest time to wait for a broker connection (if latency is > the broker will
// be considered invalid and not available for selection.), default: 500 milliseconds
// cmc.CheckManager.Broker.MaxResponseTime = "500ms"
// note: if broker Id or SelectTag are not specified, a broker will be selected randomly
// from the list of brokers available to the api token. enterprise brokers take precedence
// viable brokers are "active", have the "httptrap" module enabled, are reachable and respond
// within MaxResponseTime.
cmc.CheckManager.Broker.ID = ""
cmc.CheckManager.Broker.SelectTag = ""
cmc.CheckManager.Broker.MaxResponseTime = "500ms"
logger.Println("Creating new cgm instance")
metrics, err := cgm.NewCirconusMetrics(cmc)
if err != nil {
panic(err)
logger.Println(err)
os.Exit(1)
}
src := rand.NewSource(time.Now().UnixNano())
@@ -201,13 +202,13 @@ func main() {
### HTTP Handler wrapping
```
```go
http.HandleFunc("/", metrics.TrackHTTPLatency("/", handler_func))
```
### HTTP latency example
```
```go
package main
import (

View File

@@ -97,6 +97,7 @@ func NewAPI(ac *Config) (*API, error) {
au = fmt.Sprintf("https://%s/v2", ac.URL)
}
if last := len(au) - 1; last >= 0 && au[last] == '/' {
// strip off trailing '/'
au = au[:last]
}
apiURL, err := url.Parse(au)
@@ -143,10 +144,13 @@ func (a *API) apiCall(reqMethod string, reqPath string, data []byte) ([]byte, er
dataReader := bytes.NewReader(data)
reqURL := a.apiURL.String()
if reqPath == "" {
return nil, errors.New("Invalid URL path")
}
if reqPath[:1] != "/" {
reqURL += "/"
}
if reqPath[:3] == "/v2" {
if len(reqPath) >= 3 && reqPath[:3] == "/v2" {
reqURL += reqPath[3:len(reqPath)]
} else {
reqURL += reqPath
@@ -172,7 +176,9 @@ func (a *API) apiCall(reqMethod string, reqPath string, data []byte) ([]byte, er
// errors and may relate to outages on the server side. This will catch
// invalid response codes as well, like 0 and 999.
// Retry on 429 (rate limit) as well.
if resp.StatusCode == 0 || resp.StatusCode >= 500 || resp.StatusCode == 429 {
if resp.StatusCode == 0 || // wtf?!
resp.StatusCode >= 500 || // rutroh
resp.StatusCode == 429 { // rate limit
body, readErr := ioutil.ReadAll(resp.Body)
if readErr != nil {
lastHTTPError = fmt.Errorf("- last HTTP error: %d %+v", resp.StatusCode, readErr)

View File

@@ -7,6 +7,8 @@ package api
import (
"encoding/json"
"fmt"
"net/url"
"regexp"
"strings"
)
@@ -26,7 +28,7 @@ type BrokerDetail struct {
// Broker definition
type Broker struct {
Cid string `json:"_cid"`
CID string `json:"_cid"`
Details []BrokerDetail `json:"_details"`
Latitude string `json:"_latitude"`
Longitude string `json:"_longitude"`
@@ -35,15 +37,27 @@ type Broker struct {
Type string `json:"_type"`
}
var baseBrokerPath = "/broker"
// FetchBrokerByID fetch a broker configuration by [group]id
func (a *API) FetchBrokerByID(id IDType) (*Broker, error) {
cid := CIDType(fmt.Sprintf("/broker/%d", id))
cid := CIDType(fmt.Sprintf("%s/%d", baseBrokerPath, id))
return a.FetchBrokerByCID(cid)
}
// FetchBrokerByCID fetch a broker configuration by cid
func (a *API) FetchBrokerByCID(cid CIDType) (*Broker, error) {
result, err := a.Get(string(cid))
if matched, err := regexp.MatchString("^"+baseBrokerPath+"/[0-9]+$", string(cid)); err != nil {
return nil, err
} else if !matched {
return nil, fmt.Errorf("Invalid broker CID %v", cid)
}
reqURL := url.URL{
Path: string(cid),
}
result, err := a.Get(reqURL.String())
if err != nil {
return nil, err
}

View File

@@ -6,8 +6,10 @@ package api
import (
"encoding/json"
"errors"
"fmt"
"net/url"
"regexp"
"strings"
)
@@ -18,22 +20,30 @@ type CheckDetails struct {
// Check definition
type Check struct {
Cid string `json:"_cid"`
CID string `json:"_cid"`
Active bool `json:"_active"`
BrokerCid string `json:"_broker"`
CheckBundleCid string `json:"_check_bundle"`
BrokerCID string `json:"_broker"`
CheckBundleCID string `json:"_check_bundle"`
CheckUUID string `json:"_check_uuid"`
Details CheckDetails `json:"_details"`
}
var baseCheckPath = "/check"
// FetchCheckByID fetch a check configuration by id
func (a *API) FetchCheckByID(id IDType) (*Check, error) {
cid := CIDType(fmt.Sprintf("/check/%d", int(id)))
cid := CIDType(fmt.Sprintf("%s/%d", baseCheckPath, int(id)))
return a.FetchCheckByCID(cid)
}
// FetchCheckByCID fetch a check configuration by cid
func (a *API) FetchCheckByCID(cid CIDType) (*Check, error) {
if matched, err := regexp.MatchString("^"+baseCheckPath+"/[0-9]+$", string(cid)); err != nil {
return nil, err
} else if !matched {
return nil, fmt.Errorf("Invalid check CID %v", cid)
}
result, err := a.Get(string(cid))
if err != nil {
return nil, err
@@ -49,6 +59,9 @@ func (a *API) FetchCheckByCID(cid CIDType) (*Check, error) {
// FetchCheckBySubmissionURL fetch a check configuration by submission_url
func (a *API) FetchCheckBySubmissionURL(submissionURL URLType) (*Check, error) {
if string(submissionURL) == "" {
return nil, errors.New("[ERROR] Invalid submission URL (blank)")
}
u, err := url.Parse(string(submissionURL))
if err != nil {
@@ -99,10 +112,18 @@ func (a *API) FetchCheckBySubmissionURL(submissionURL URLType) (*Check, error) {
}
// CheckSearch returns a list of checks matching a search query
func (a *API) CheckSearch(query SearchQueryType) ([]Check, error) {
queryURL := fmt.Sprintf("/check?search=%s", string(query))
func (a *API) CheckSearch(searchCriteria SearchQueryType) ([]Check, error) {
reqURL := url.URL{
Path: baseCheckPath,
}
result, err := a.Get(queryURL)
if searchCriteria != "" {
q := url.Values{}
q.Set("search", string(searchCriteria))
reqURL.RawQuery = q.Encode()
}
result, err := a.Get(reqURL.String())
if err != nil {
return nil, err
}
@@ -115,8 +136,13 @@ func (a *API) CheckSearch(query SearchQueryType) ([]Check, error) {
return checks, nil
}
// CheckFilterSearch returns a list of checks matching a filter
// CheckFilterSearch returns a list of checks matching a filter (filtering allows looking for
// things within sub-elements e.g. details)
func (a *API) CheckFilterSearch(filter SearchFilterType) ([]Check, error) {
if filter == "" {
return nil, errors.New("[ERROR] invalid filter supplied (blank)")
}
filterURL := fmt.Sprintf("/check?%s", string(filter))
result, err := a.Get(filterURL)

View File

@@ -7,6 +7,8 @@ package api
import (
"encoding/json"
"fmt"
"net/url"
"regexp"
)
// CheckBundleConfig configuration specific to check type
@@ -36,7 +38,7 @@ type CheckBundleMetric struct {
type CheckBundle struct {
CheckUUIDs []string `json:"_check_uuids,omitempty"`
Checks []string `json:"_checks,omitempty"`
Cid string `json:"_cid,omitempty"`
CID string `json:"_cid,omitempty"`
Created int `json:"_created,omitempty"`
LastModified int `json:"_last_modified,omitempty"`
LastModifedBy string `json:"_last_modifed_by,omitempty"`
@@ -55,15 +57,27 @@ type CheckBundle struct {
Type string `json:"type"`
}
var baseCheckBundlePath = "/check_bundle"
// FetchCheckBundleByID fetch a check bundle configuration by id
func (a *API) FetchCheckBundleByID(id IDType) (*CheckBundle, error) {
cid := CIDType(fmt.Sprintf("/check_bundle/%d", id))
cid := CIDType(fmt.Sprintf("%s/%d", baseCheckBundlePath, id))
return a.FetchCheckBundleByCID(cid)
}
// FetchCheckBundleByCID fetch a check bundle configuration by id
func (a *API) FetchCheckBundleByCID(cid CIDType) (*CheckBundle, error) {
result, err := a.Get(string(cid))
if matched, err := regexp.MatchString("^"+baseCheckBundlePath+"/[0-9]+$", string(cid)); err != nil {
return nil, err
} else if !matched {
return nil, fmt.Errorf("Invalid check bundle CID %v", cid)
}
reqURL := url.URL{
Path: string(cid),
}
result, err := a.Get(reqURL.String())
if err != nil {
return nil, err
}
@@ -77,17 +91,55 @@ func (a *API) FetchCheckBundleByCID(cid CIDType) (*CheckBundle, error) {
}
// CheckBundleSearch returns list of check bundles matching a search query
// - a search query not a filter (see: https://login.circonus.com/resources/api#searching)
// - a search query (see: https://login.circonus.com/resources/api#searching)
func (a *API) CheckBundleSearch(searchCriteria SearchQueryType) ([]CheckBundle, error) {
apiPath := fmt.Sprintf("/check_bundle?search=%s", searchCriteria)
reqURL := url.URL{
Path: baseCheckBundlePath,
}
response, err := a.Get(apiPath)
if searchCriteria != "" {
q := url.Values{}
q.Set("search", string(searchCriteria))
reqURL.RawQuery = q.Encode()
}
resp, err := a.Get(reqURL.String())
if err != nil {
return nil, fmt.Errorf("[ERROR] API call error %+v", err)
}
var results []CheckBundle
if err := json.Unmarshal(response, &results); err != nil {
if err := json.Unmarshal(resp, &results); err != nil {
return nil, err
}
return results, nil
}
// CheckBundleFilterSearch returns list of check bundles matching a search query and filter
// - a search query (see: https://login.circonus.com/resources/api#searching)
// - a filter (see: https://login.circonus.com/resources/api#filtering)
func (a *API) CheckBundleFilterSearch(searchCriteria SearchQueryType, filterCriteria map[string]string) ([]CheckBundle, error) {
reqURL := url.URL{
Path: baseCheckBundlePath,
}
if searchCriteria != "" {
q := url.Values{}
q.Set("search", string(searchCriteria))
for field, val := range filterCriteria {
q.Set(field, val)
}
reqURL.RawQuery = q.Encode()
}
resp, err := a.Get(reqURL.String())
if err != nil {
return nil, fmt.Errorf("[ERROR] API call error %+v", err)
}
var results []CheckBundle
if err := json.Unmarshal(resp, &results); err != nil {
return nil, err
}
@@ -95,19 +147,23 @@ func (a *API) CheckBundleSearch(searchCriteria SearchQueryType) ([]CheckBundle,
}
// CreateCheckBundle create a new check bundle (check)
func (a *API) CreateCheckBundle(config CheckBundle) (*CheckBundle, error) {
cfgJSON, err := json.Marshal(config)
func (a *API) CreateCheckBundle(config *CheckBundle) (*CheckBundle, error) {
reqURL := url.URL{
Path: baseCheckBundlePath,
}
cfg, err := json.Marshal(config)
if err != nil {
return nil, err
}
response, err := a.Post("/check_bundle", cfgJSON)
resp, err := a.Post(reqURL.String(), cfg)
if err != nil {
return nil, err
}
checkBundle := &CheckBundle{}
if err := json.Unmarshal(response, checkBundle); err != nil {
if err := json.Unmarshal(resp, checkBundle); err != nil {
return nil, err
}
@@ -116,22 +172,28 @@ func (a *API) CreateCheckBundle(config CheckBundle) (*CheckBundle, error) {
// UpdateCheckBundle updates a check bundle configuration
func (a *API) UpdateCheckBundle(config *CheckBundle) (*CheckBundle, error) {
if a.Debug {
a.Log.Printf("[DEBUG] Updating check bundle.")
if matched, err := regexp.MatchString("^"+baseCheckBundlePath+"/[0-9]+$", string(config.CID)); err != nil {
return nil, err
} else if !matched {
return nil, fmt.Errorf("Invalid check bundle CID %v", config.CID)
}
cfgJSON, err := json.Marshal(config)
reqURL := url.URL{
Path: config.CID,
}
cfg, err := json.Marshal(config)
if err != nil {
return nil, err
}
response, err := a.Put(config.Cid, cfgJSON)
resp, err := a.Put(reqURL.String(), cfg)
if err != nil {
return nil, err
}
checkBundle := &CheckBundle{}
if err := json.Unmarshal(response, checkBundle); err != nil {
if err := json.Unmarshal(resp, checkBundle); err != nil {
return nil, err
}
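A short sketch of the new CheckBundleFilterSearch entry point, mirroring how the check manager later uses it (the import path and the TokenKey field are assumptions):
```go
package main

import (
	"log"

	"github.com/circonus-labs/circonus-gometrics/api" // assumed import path
)

func main() {
	client, err := api.NewAPI(&api.Config{TokenKey: "abc-123"}) // TokenKey assumed
	if err != nil {
		log.Fatal(err)
	}

	// A search query plus a filter on sub-elements; filter keys are passed
	// through as query parameters (e.g. f_notes), which is how the check
	// manager below looks up bundles by its notes marker.
	search := api.SearchQueryType(`(active:1)(type:"httptrap")`)
	filter := map[string]string{"f_notes": "cgm_instanceid|myhost:myapp"}

	bundles, err := client.CheckBundleFilterSearch(search, filter)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("found %d check bundle(s)", len(bundles))
}
```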

View File

@@ -0,0 +1,159 @@
// Copyright 2016 Circonus, Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package api
import (
"encoding/json"
"fmt"
"net/url"
"regexp"
)
// MetricQuery object
type MetricQuery struct {
Query string `json:"query"`
Type string `json:"type"`
}
// MetricCluster object
type MetricCluster struct {
CID string `json:"_cid,omitempty"`
MatchingMetrics []string `json:"_matching_metrics,omitempty"`
MatchingUUIDMetrics map[string][]string `json:"_matching_uuid_metrics,omitempty"`
Description string `json:"description"`
Name string `json:"name"`
Queries []MetricQuery `json:"queries"`
Tags []string `json:"tags"`
}
var baseMetricClusterPath = "/metric_cluster"
// FetchMetricClusterByID fetch a metric cluster configuration by id
func (a *API) FetchMetricClusterByID(id IDType, extras string) (*MetricCluster, error) {
reqURL := url.URL{
Path: fmt.Sprintf("%s/%d", baseMetricClusterPath, id),
}
cid := CIDType(reqURL.String())
return a.FetchMetricClusterByCID(cid, extras)
}
// FetchMetricClusterByCID fetch a check bundle configuration by id
func (a *API) FetchMetricClusterByCID(cid CIDType, extras string) (*MetricCluster, error) {
if matched, err := regexp.MatchString("^"+baseMetricClusterPath+"/[0-9]+$", string(cid)); err != nil {
return nil, err
} else if !matched {
return nil, fmt.Errorf("Invalid metric cluster CID %v", cid)
}
reqURL := url.URL{
Path: string(cid),
}
extra := ""
switch extras {
case "metrics":
extra = "_matching_metrics"
case "uuids":
extra = "_matching_uuid_metrics"
}
if extra != "" {
q := url.Values{}
q.Set("extra", extra)
reqURL.RawQuery = q.Encode()
}
resp, err := a.Get(reqURL.String())
if err != nil {
return nil, err
}
cluster := &MetricCluster{}
if err := json.Unmarshal(resp, cluster); err != nil {
return nil, err
}
return cluster, nil
}
// MetricClusterSearch returns list of metric clusters matching a search query (or all metric
// clusters if no search query is provided)
// - a search query not a filter (see: https://login.circonus.com/resources/api#searching)
func (a *API) MetricClusterSearch(searchCriteria SearchQueryType) ([]MetricCluster, error) {
reqURL := url.URL{
Path: baseMetricClusterPath,
}
if searchCriteria != "" {
q := url.Values{}
q.Set("search", string(searchCriteria))
reqURL.RawQuery = q.Encode()
}
resp, err := a.Get(reqURL.String())
if err != nil {
return nil, fmt.Errorf("[ERROR] API call error %+v", err)
}
var clusters []MetricCluster
if err := json.Unmarshal(resp, &clusters); err != nil {
return nil, err
}
return clusters, nil
}
// CreateMetricCluster create a new metric cluster
func (a *API) CreateMetricCluster(config *MetricCluster) (*MetricCluster, error) {
reqURL := url.URL{
Path: baseMetricClusterPath,
}
cfg, err := json.Marshal(config)
if err != nil {
return nil, err
}
resp, err := a.Post(reqURL.String(), cfg)
if err != nil {
return nil, err
}
cluster := &MetricCluster{}
if err := json.Unmarshal(resp, cluster); err != nil {
return nil, err
}
return cluster, nil
}
// UpdateMetricCluster updates a metric cluster
func (a *API) UpdateMetricCluster(config *MetricCluster) (*MetricCluster, error) {
if matched, err := regexp.MatchString("^"+baseMetricClusterPath+"/[0-9]+$", string(config.CID)); err != nil {
return nil, err
} else if !matched {
return nil, fmt.Errorf("Invalid metric cluster CID %v", config.CID)
}
reqURL := url.URL{
Path: config.CID,
}
cfg, err := json.Marshal(config)
if err != nil {
return nil, err
}
resp, err := a.Put(reqURL.String(), cfg)
if err != nil {
return nil, err
}
cluster := &MetricCluster{}
if err := json.Unmarshal(resp, cluster); err != nil {
return nil, err
}
return cluster, nil
}
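A minimal sketch of fetching one of the new metric cluster objects; the import path and Config fields are assumptions, the FetchMetricClusterByID signature is as defined above:
```go
package main

import (
	"log"

	"github.com/circonus-labs/circonus-gometrics/api" // assumed import path
)

func main() {
	client, err := api.NewAPI(&api.Config{TokenKey: "abc-123"}) // TokenKey assumed
	if err != nil {
		log.Fatal(err)
	}

	// extras selects the matching-metric view: "metrics" requests
	// _matching_metrics, "uuids" requests _matching_uuid_metrics, anything
	// else fetches the bare cluster definition.
	cluster, err := client.FetchMetricClusterByID(api.IDType(42), "metrics")
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("%s: %d matching metrics", cluster.Name, len(cluster.MatchingMetrics))
}
```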

View File

@@ -101,7 +101,7 @@ func (cm *CheckManager) selectBroker() (*api.Broker, error) {
for _, broker := range brokerList {
if cm.isValidBroker(&broker) {
validBrokers[broker.Cid] = broker
validBrokers[broker.CID] = broker
if broker.Type == "enterprise" {
haveEnterprise = true
}

View File

@@ -7,6 +7,7 @@ package checkmgr
import (
"crypto/x509"
"encoding/json"
"errors"
"fmt"
)
@@ -65,7 +66,7 @@ func (cm *CheckManager) loadCACert() {
// fetchCert fetches CA certificate using Circonus API
func (cm *CheckManager) fetchCert() ([]byte, error) {
if !cm.enabled {
return circonusCA, nil
return nil, errors.New("check manager is not enabled")
}
response, err := cm.apih.Get("/pki/ca.crt")

View File

@@ -35,7 +35,7 @@ func (cm *CheckManager) UpdateCheck(newMetrics map[string]*api.CheckBundleMetric
}
// refresh check bundle (in case there were changes made by other apps or in UI)
checkBundle, err := cm.apih.FetchCheckBundleByCID(api.CIDType(cm.checkBundle.Cid))
checkBundle, err := cm.apih.FetchCheckBundleByCID(api.CIDType(cm.checkBundle.CID))
if err != nil {
cm.Log.Printf("[ERROR] unable to fetch up-to-date check bundle %v", err)
return
@@ -105,7 +105,7 @@ func (cm *CheckManager) initializeTrapURL() error {
}
if !cm.enabled {
return errors.New("Unable to initialize trap, check manager is disabled.")
return errors.New("unable to initialize trap, check manager is disabled")
}
var err error
@@ -119,7 +119,7 @@ func (cm *CheckManager) initializeTrapURL() error {
return err
}
if !check.Active {
return fmt.Errorf("[ERROR] Check ID %v is not active", check.Cid)
return fmt.Errorf("[ERROR] Check ID %v is not active", check.CID)
}
// extract check id from check object returned from looking up using submission url
// set m.CheckId to the id
@@ -128,14 +128,14 @@ func (cm *CheckManager) initializeTrapURL() error {
// unless the new submission url can be fetched with the API (which is no
// longer possible using the original submission url)
var id int
id, err = strconv.Atoi(strings.Replace(check.Cid, "/check/", "", -1))
id, err = strconv.Atoi(strings.Replace(check.CID, "/check/", "", -1))
if err == nil {
cm.checkID = api.IDType(id)
cm.checkSubmissionURL = ""
} else {
cm.Log.Printf(
"[WARN] SubmissionUrl check to Check ID: unable to convert %s to int %q\n",
check.Cid, err)
check.CID, err)
}
} else if cm.checkID > 0 {
check, err = cm.apih.FetchCheckByID(cm.checkID)
@@ -143,15 +143,28 @@ func (cm *CheckManager) initializeTrapURL() error {
return err
}
if !check.Active {
return fmt.Errorf("[ERROR] Check ID %v is not active", check.Cid)
return fmt.Errorf("[ERROR] Check ID %v is not active", check.CID)
}
} else {
searchCriteria := fmt.Sprintf(
"(active:1)(host:\"%s\")(type:\"%s\")(tags:%s)(notes:%s)",
cm.checkTarget, cm.checkType, strings.Join(cm.checkSearchTag, ","), fmt.Sprintf("cgm_instanceid=%s", cm.checkInstanceID))
checkBundle, err = cm.checkBundleSearch(searchCriteria)
if err != nil {
return err
if checkBundle == nil {
// old search (instanceid as check.target)
searchCriteria := fmt.Sprintf(
"(active:1)(type:\"%s\")(host:\"%s\")(tags:%s)", cm.checkType, cm.checkTarget, strings.Join(cm.checkSearchTag, ","))
checkBundle, err = cm.checkBundleSearch(searchCriteria, map[string]string{})
if err != nil {
return err
}
}
if checkBundle == nil {
// new search (check.target != instanceid, instanceid encoded in notes field)
searchCriteria := fmt.Sprintf(
"(active:1)(type:\"%s\")(tags:%s)", cm.checkType, strings.Join(cm.checkSearchTag, ","))
filterCriteria := map[string]string{"f_notes": cm.getNotes()}
checkBundle, err = cm.checkBundleSearch(searchCriteria, filterCriteria)
if err != nil {
return err
}
}
if checkBundle == nil {
@@ -166,7 +179,7 @@ func (cm *CheckManager) initializeTrapURL() error {
if checkBundle == nil {
if check != nil {
checkBundle, err = cm.apih.FetchCheckBundleByCID(api.CIDType(check.CheckBundleCid))
checkBundle, err = cm.apih.FetchCheckBundleByCID(api.CIDType(check.CheckBundleCID))
if err != nil {
return err
}
@@ -214,8 +227,8 @@ func (cm *CheckManager) initializeTrapURL() error {
}
// Search for a check bundle given a predetermined set of criteria
func (cm *CheckManager) checkBundleSearch(criteria string) (*api.CheckBundle, error) {
checkBundles, err := cm.apih.CheckBundleSearch(api.SearchQueryType(criteria))
func (cm *CheckManager) checkBundleSearch(criteria string, filter map[string]string) (*api.CheckBundle, error) {
checkBundles, err := cm.apih.CheckBundleFilterSearch(api.SearchQueryType(criteria), filter)
if err != nil {
return nil, err
}
@@ -235,7 +248,7 @@ func (cm *CheckManager) checkBundleSearch(criteria string) (*api.CheckBundle, er
}
if numActive > 1 {
return nil, fmt.Errorf("[ERROR] Multiple possibilities multiple check bundles match criteria %s\n", criteria)
return nil, fmt.Errorf("[ERROR] multiple check bundles match criteria %s", criteria)
}
return &checkBundles[checkID], nil
@@ -257,17 +270,17 @@ func (cm *CheckManager) createNewCheck() (*api.CheckBundle, *api.Broker, error)
return nil, nil, err
}
config := api.CheckBundle{
Brokers: []string{broker.Cid},
config := &api.CheckBundle{
Brokers: []string{broker.CID},
Config: api.CheckBundleConfig{AsyncMetrics: true, Secret: checkSecret},
DisplayName: string(cm.checkDisplayName),
Metrics: []api.CheckBundleMetric{},
MetricLimit: 0,
Notes: fmt.Sprintf("cgm_instanceid=%s", cm.checkInstanceID),
Notes: cm.getNotes(),
Period: 60,
Status: statusActive,
Tags: append(cm.checkSearchTag, cm.checkTags...),
Target: cm.checkTarget,
Target: string(cm.checkTarget),
Timeout: 10,
Type: string(cm.checkType),
}
@@ -290,3 +303,7 @@ func (cm *CheckManager) makeSecret() (string, error) {
hash.Write(x)
return hex.EncodeToString(hash.Sum(nil))[0:16], nil
}
func (cm *CheckManager) getNotes() string {
return fmt.Sprintf("cgm_instanceid|%s", cm.checkInstanceID)
}

View File

@@ -59,12 +59,15 @@ type CheckConfig struct {
// used to search for a check to use
// used as check.target when creating a check
InstanceID string
// explicitly set check.target (default: instance id)
TargetHost string
// a custom display name for the check (as viewed in UI Checks)
// default: instance id
DisplayName string
// unique check searching tag (or tags)
// used to search for a check to use (combined with instanceid)
// used as a regular tag when creating a check
SearchTag string
// a custom display name for the check (as viewed in UI Checks)
DisplayName string
// httptrap check secret (for creating a check)
Secret string
// additional tags to add to a check (when creating a check)
@@ -115,6 +118,9 @@ type CheckTypeType string
// CheckInstanceIDType check instance id
type CheckInstanceIDType string
// CheckTargetType check target/host
type CheckTargetType string
// CheckSecretType check secret
type CheckSecretType string
@@ -138,7 +144,7 @@ type CheckManager struct {
checkType CheckTypeType
checkID api.IDType
checkInstanceID CheckInstanceIDType
checkTarget string
checkTarget CheckTargetType
checkSearchTag api.TagType
checkSecret CheckSecretType
checkTags api.TagType
@@ -178,7 +184,7 @@ type Trap struct {
func NewCheckManager(cfg *Config) (*CheckManager, error) {
if cfg == nil {
return nil, errors.New("Invalid Check Manager configuration (nil).")
return nil, errors.New("invalid Check Manager configuration (nil)")
}
cm := &CheckManager{
@@ -200,7 +206,7 @@ func NewCheckManager(cfg *Config) (*CheckManager, error) {
// Blank API Token *disables* check management
if cfg.API.TokenKey == "" {
if cm.checkSubmissionURL == "" {
return nil, errors.New("Invalid check manager configuration (no API token AND no submission url).")
return nil, errors.New("invalid check manager configuration (no API token AND no submission url)")
}
if err := cm.initializeTrapURL(); err != nil {
return nil, err
@@ -238,6 +244,7 @@ func NewCheckManager(cfg *Config) (*CheckManager, error) {
cm.checkID = api.IDType(id)
cm.checkInstanceID = CheckInstanceIDType(cfg.Check.InstanceID)
cm.checkTarget = CheckTargetType(cfg.Check.TargetHost)
cm.checkDisplayName = CheckDisplayNameType(cfg.Check.DisplayName)
cm.checkSecret = CheckSecretType(cfg.Check.Secret)
@@ -259,7 +266,12 @@ func NewCheckManager(cfg *Config) (*CheckManager, error) {
if cm.checkInstanceID == "" {
cm.checkInstanceID = CheckInstanceIDType(fmt.Sprintf("%s:%s", hn, an))
}
cm.checkTarget = hn
if cm.checkDisplayName == "" {
cm.checkDisplayName = CheckDisplayNameType(cm.checkInstanceID)
}
if cm.checkTarget == "" {
cm.checkTarget = CheckTargetType(cm.checkInstanceID)
}
if cfg.Check.SearchTag == "" {
cm.checkSearchTag = []string{fmt.Sprintf("service:%s", an)}
@@ -271,10 +283,6 @@ func NewCheckManager(cfg *Config) (*CheckManager, error) {
cm.checkTags = strings.Split(strings.Replace(cfg.Check.Tags, " ", "", -1), ",")
}
if cm.checkDisplayName == "" {
cm.checkDisplayName = CheckDisplayNameType(fmt.Sprintf("%s", string(cm.checkInstanceID)))
}
dur := cfg.Check.MaxURLAge
if dur == "" {
dur = defaultTrapMaxURLAge
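A brief configuration sketch showing the new TargetHost and DisplayName options alongside the existing fields; the import path is assumed, the field names are as in the diff:
```go
package main

import (
	"log"

	"github.com/circonus-labs/circonus-gometrics/checkmgr" // assumed import path
)

func main() {
	cfg := &checkmgr.Config{}
	cfg.API.TokenKey = "abc-123" // placeholder token; enables check management

	// New in this revision: TargetHost and DisplayName are independent of
	// InstanceID and both fall back to it when left blank.
	cfg.Check.InstanceID = "myhost:myapp"
	cfg.Check.TargetHost = "10.0.0.5"
	cfg.Check.DisplayName = "myapp on myhost"
	cfg.Check.SearchTag = "service:myapp"

	cm, err := checkmgr.NewCheckManager(cfg)
	if err != nil {
		log.Fatal(err)
	}
	_ = cm
}
```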

View File

@@ -101,7 +101,7 @@ type CirconusMetrics struct {
func NewCirconusMetrics(cfg *Config) (*CirconusMetrics, error) {
if cfg == nil {
return nil, errors.New("Invalid configuration (nil).")
return nil, errors.New("invalid configuration (nil)")
}
cm := &CirconusMetrics{
@@ -114,66 +114,84 @@ func NewCirconusMetrics(cfg *Config) (*CirconusMetrics, error) {
textFuncs: make(map[string]func() string),
}
cm.Debug = cfg.Debug
cm.Log = cfg.Log
// Logging
{
cm.Debug = cfg.Debug
cm.Log = cfg.Log
if cm.Debug && cfg.Log == nil {
cm.Log = log.New(os.Stderr, "", log.LstdFlags)
}
if cm.Log == nil {
cm.Log = log.New(ioutil.Discard, "", log.LstdFlags)
if cm.Debug && cfg.Log == nil {
cm.Log = log.New(os.Stderr, "", log.LstdFlags)
}
if cm.Log == nil {
cm.Log = log.New(ioutil.Discard, "", log.LstdFlags)
}
}
fi := defaultFlushInterval
if cfg.Interval != "" {
fi = cfg.Interval
// Flush Interval
{
fi := defaultFlushInterval
if cfg.Interval != "" {
fi = cfg.Interval
}
dur, err := time.ParseDuration(fi)
if err != nil {
return nil, err
}
cm.flushInterval = dur
}
dur, err := time.ParseDuration(fi)
if err != nil {
return nil, err
}
cm.flushInterval = dur
var setting bool
// metric resets
cm.resetCounters = true
if cfg.ResetCounters != "" {
if setting, err = strconv.ParseBool(cfg.ResetCounters); err == nil {
cm.resetCounters = setting
setting, err := strconv.ParseBool(cfg.ResetCounters)
if err != nil {
return nil, err
}
cm.resetCounters = setting
}
cm.resetGauges = true
if cfg.ResetGauges != "" {
if setting, err = strconv.ParseBool(cfg.ResetGauges); err == nil {
cm.resetGauges = setting
setting, err := strconv.ParseBool(cfg.ResetGauges)
if err != nil {
return nil, err
}
cm.resetGauges = setting
}
cm.resetHistograms = true
if cfg.ResetHistograms != "" {
if setting, err = strconv.ParseBool(cfg.ResetHistograms); err == nil {
cm.resetHistograms = setting
setting, err := strconv.ParseBool(cfg.ResetHistograms)
if err != nil {
return nil, err
}
cm.resetHistograms = setting
}
cm.resetText = true
if cfg.ResetText != "" {
if setting, err = strconv.ParseBool(cfg.ResetText); err == nil {
cm.resetText = setting
setting, err := strconv.ParseBool(cfg.ResetText)
if err != nil {
return nil, err
}
cm.resetText = setting
}
cfg.CheckManager.Debug = cm.Debug
cfg.CheckManager.Log = cm.Log
// check manager
{
cfg.CheckManager.Debug = cm.Debug
cfg.CheckManager.Log = cm.Log
check, err := checkmgr.NewCheckManager(&cfg.CheckManager)
if err != nil {
return nil, err
check, err := checkmgr.NewCheckManager(&cfg.CheckManager)
if err != nil {
return nil, err
}
cm.check = check
}
cm.check = check
// initialize
if _, err := cm.check.GetTrap(); err != nil {
return nil, err
}
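A minimal construction sketch reflecting the stricter option parsing above: the Reset* options are still strings, but an unparseable value now surfaces as an error from NewCirconusMetrics (import path assumed):
```go
package main

import (
	"log"

	cgm "github.com/circonus-labs/circonus-gometrics" // assumed import path
)

func main() {
	cfg := &cgm.Config{}
	cfg.CheckManager.API.TokenKey = "abc-123" // placeholder token
	cfg.Interval = "10s"

	// Reset* options are strings; after this change an unparseable value is
	// returned as an error instead of being silently ignored.
	cfg.ResetCounters = "true"
	cfg.ResetGauges = "false"

	metrics, err := cgm.NewCirconusMetrics(cfg)
	if err != nil {
		log.Fatal(err)
	}
	_ = metrics // ready to record and flush metrics
}
```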

View File

@@ -42,80 +42,99 @@ func (m *CirconusMetrics) Reset() {
// snapshot returns a copy of the values of all registered counters and gauges.
func (m *CirconusMetrics) snapshot() (c map[string]uint64, g map[string]string, h map[string]*circonusllhist.Histogram, t map[string]string) {
c = m.snapCounters()
g = m.snapGauges()
h = m.snapHistograms()
t = m.snapText()
return
}
func (m *CirconusMetrics) snapCounters() map[string]uint64 {
c := make(map[string]uint64, len(m.counters)+len(m.counterFuncs))
m.cm.Lock()
defer m.cm.Unlock()
m.cfm.Lock()
defer m.cfm.Unlock()
m.gm.Lock()
defer m.gm.Unlock()
m.gfm.Lock()
defer m.gfm.Unlock()
m.hm.Lock()
defer m.hm.Unlock()
m.tm.Lock()
defer m.tm.Unlock()
m.tfm.Lock()
defer m.tfm.Unlock()
c = make(map[string]uint64, len(m.counters)+len(m.counterFuncs))
for n, v := range m.counters {
c[n] = v
}
if m.resetCounters && len(c) > 0 {
m.counters = make(map[string]uint64)
}
m.cm.Unlock()
m.cfm.Lock()
for n, f := range m.counterFuncs {
c[n] = f()
}
if m.resetCounters && len(c) > 0 {
m.counterFuncs = make(map[string]func() uint64)
}
m.cfm.Unlock()
//g = make(map[string]int64, len(m.gauges)+len(m.gaugeFuncs))
g = make(map[string]string, len(m.gauges)+len(m.gaugeFuncs))
return c
}
func (m *CirconusMetrics) snapGauges() map[string]string {
g := make(map[string]string, len(m.gauges)+len(m.gaugeFuncs))
m.gm.Lock()
for n, v := range m.gauges {
g[n] = v
}
if m.resetGauges && len(g) > 0 {
m.gauges = make(map[string]string)
}
m.gm.Unlock()
m.gfm.Lock()
for n, f := range m.gaugeFuncs {
g[n] = m.gaugeValString(f())
}
if m.resetGauges && len(g) > 0 {
m.gaugeFuncs = make(map[string]func() int64)
}
m.gfm.Unlock()
h = make(map[string]*circonusllhist.Histogram, len(m.histograms))
return g
}
func (m *CirconusMetrics) snapHistograms() map[string]*circonusllhist.Histogram {
h := make(map[string]*circonusllhist.Histogram, len(m.histograms))
m.hm.Lock()
for n, hist := range m.histograms {
hist.rw.Lock()
h[n] = hist.hist.CopyAndReset()
hist.rw.Unlock()
}
if m.resetHistograms && len(h) > 0 {
m.histograms = make(map[string]*Histogram)
}
m.hm.Unlock()
t = make(map[string]string, len(m.text)+len(m.textFuncs))
return h
}
func (m *CirconusMetrics) snapText() map[string]string {
t := make(map[string]string, len(m.text)+len(m.textFuncs))
m.tm.Lock()
for n, v := range m.text {
t[n] = v
}
if m.resetText && len(t) > 0 {
m.text = make(map[string]string)
}
m.tm.Unlock()
m.tfm.Lock()
for n, f := range m.textFuncs {
t[n] = f()
}
if m.resetCounters {
m.counters = make(map[string]uint64)
m.counterFuncs = make(map[string]func() uint64)
}
if m.resetGauges {
m.gauges = make(map[string]string)
m.gaugeFuncs = make(map[string]func() int64)
}
if m.resetHistograms {
m.histograms = make(map[string]*Histogram)
}
if m.resetText {
m.text = make(map[string]string)
if m.resetText && len(t) > 0 {
m.textFuncs = make(map[string]func() string)
}
m.tfm.Unlock()
return
return t
}

File diff suppressed because it is too large

View File

@@ -6,8 +6,10 @@ import (
"errors"
)
type packetType uint8
type header struct {
PacketType uint8
PacketType packetType
Status uint8
Size uint16
Spid uint16
@@ -21,7 +23,7 @@ type tdsBuffer struct {
transport io.ReadWriteCloser
size uint16
final bool
packet_type uint8
packet_type packetType
afterFirst func()
}
@@ -84,8 +86,8 @@ func (w *tdsBuffer) WriteByte(b byte) error {
return nil
}
func (w *tdsBuffer) BeginPacket(packet_type byte) {
w.buf[0] = packet_type
func (w *tdsBuffer) BeginPacket(packet_type packetType) {
w.buf[0] = byte(packet_type)
w.buf[1] = 0 // packet is incomplete
w.buf[4] = 0 // spid
w.buf[5] = 0
@@ -124,7 +126,7 @@ func (r *tdsBuffer) readNextPacket() error {
return nil
}
func (r *tdsBuffer) BeginRead() (uint8, error) {
func (r *tdsBuffer) BeginRead() (packetType, error) {
err := r.readNextPacket()
if err != nil {
return 0, err

View File

@@ -4,19 +4,26 @@ import (
"log"
)
type Logger log.Logger
type Logger interface {
Printf(format string, v ...interface{})
Println(v ...interface{})
}
func (logger *Logger) Printf(format string, v ...interface{}) {
if logger != nil {
(*log.Logger)(logger).Printf(format, v...)
type optionalLogger struct {
logger Logger
}
func (o optionalLogger) Printf(format string, v ...interface{}) {
if o.logger != nil {
o.logger.Printf(format, v...)
} else {
log.Printf(format, v...)
}
}
func (logger *Logger) Println(v ...interface{}) {
if logger != nil {
(*log.Logger)(logger).Println(v...)
func (o optionalLogger) Println(v ...interface{}) {
if o.logger != nil {
o.logger.Println(v...)
} else {
log.Println(v...)
}
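A small sketch of the new package-level SetLogger hook; any value with Printf and Println methods satisfies the Logger interface, including the standard *log.Logger:
```go
package main

import (
	"log"
	"os"

	mssql "github.com/denisenkom/go-mssqldb" // vendored path shown above
)

func main() {
	// *log.Logger has Printf and Println, so it satisfies the new Logger
	// interface directly; any type with those two methods works as well.
	mssql.SetLogger(log.New(os.Stderr, "mssql: ", log.LstdFlags))
}
```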

View File

@@ -7,104 +7,124 @@ import (
"errors"
"fmt"
"io"
"log"
"math"
"net"
"strings"
"time"
"reflect"
"golang.org/x/net/context" // use the "x/net/context" for backwards compatibility.
)
var driverInstance = &MssqlDriver{}
func init() {
sql.Register("mssql", &MssqlDriver{})
sql.Register("mssql", driverInstance)
}
type MssqlDriver struct {
log *log.Logger
log optionalLogger
}
func (d *MssqlDriver) SetLogger(logger *log.Logger) {
d.log = logger
func SetLogger(logger Logger) {
driverInstance.SetLogger(logger)
}
func CheckBadConn(err error) error {
if err == io.EOF {
return driver.ErrBadConn
}
switch e := err.(type) {
case net.Error:
if e.Timeout() {
return e
}
return driver.ErrBadConn
default:
return err
}
func (d *MssqlDriver) SetLogger(logger Logger) {
d.log = optionalLogger{logger}
}
type MssqlConn struct {
sess *tdsSession
transactionCtx context.Context
}
func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {
tokchan := make(chan tokenStruct, 5)
go processResponse(ctx, c.sess, tokchan)
for tok := range tokchan {
switch token := tok.(type) {
case doneStruct:
if token.isError() {
return token.getError()
}
case error:
return token
}
}
return nil
}
func (c *MssqlConn) Commit() error {
if err := c.sendCommitRequest(); err != nil {
return err
}
return c.simpleProcessResp(c.transactionCtx)
}
func (c *MssqlConn) sendCommitRequest() error {
headers := []headerStruct{
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{c.sess.tranid, 1}.pack()},
}
if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
return err
}
tokchan := make(chan tokenStruct, 5)
go processResponse(c.sess, tokchan)
for tok := range tokchan {
switch token := tok.(type) {
case error:
return token
if c.sess.logFlags&logErrors != 0 {
c.sess.log.Printf("Failed to send CommitXact with %v", err)
}
return driver.ErrBadConn
}
return nil
}
func (c *MssqlConn) Rollback() error {
if err := c.sendRollbackRequest(); err != nil {
return err
}
return c.simpleProcessResp(c.transactionCtx)
}
func (c *MssqlConn) sendRollbackRequest() error {
headers := []headerStruct{
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{c.sess.tranid, 1}.pack()},
}
if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
return err
}
tokchan := make(chan tokenStruct, 5)
go processResponse(c.sess, tokchan)
for tok := range tokchan {
switch token := tok.(type) {
case error:
return token
if c.sess.logFlags&logErrors != 0 {
c.sess.log.Printf("Failed to send RollbackXact with %v", err)
}
return driver.ErrBadConn
}
return nil
}
func (c *MssqlConn) Begin() (driver.Tx, error) {
return c.begin(context.Background(), isolationUseCurrent)
}
func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (driver.Tx, error) {
err := c.sendBeginRequest(ctx, tdsIsolation)
if err != nil {
return nil, err
}
return c.processBeginResponse(ctx)
}
func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
c.transactionCtx = ctx
headers := []headerStruct{
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{0, 1}.pack()},
}
if err := sendBeginXact(c.sess.buf, headers, 0, ""); err != nil {
return nil, CheckBadConn(err)
}
tokchan := make(chan tokenStruct, 5)
go processResponse(c.sess, tokchan)
for tok := range tokchan {
switch token := tok.(type) {
case error:
if c.sess.tranid != 0 {
return nil, token
}
return nil, CheckBadConn(token)
if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.log.Printf("Failed to send BeginXact with %v", err)
}
return driver.ErrBadConn
}
return nil
}
func (c *MssqlConn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
if err := c.simpleProcessResp(ctx); err != nil {
return nil, err
}
// successful BEGINXACT request will return sess.tranid
// for started transaction
@@ -112,6 +132,10 @@ func (c *MssqlConn) Begin() (driver.Tx, error) {
}
func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) {
return d.open(dsn)
}
func (d *MssqlDriver) open(dsn string) (*MssqlConn, error) {
params, err := parseConnectParams(dsn)
if err != nil {
return nil, err
@@ -136,8 +160,8 @@ func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) {
}
}
conn := &MssqlConn{sess}
conn.sess.log = (*Logger)(d.log)
conn := &MssqlConn{sess, context.Background()}
conn.sess.log = d.log
return conn, nil
}
@@ -159,6 +183,10 @@ type queryNotifSub struct {
}
func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) {
return c.prepareContext(context.Background(), query)
}
func (c *MssqlConn) prepareContext(ctx context.Context, query string) (*MssqlStmt, error) {
q, paramCount := parseParams(query)
return &MssqlStmt{c, q, paramCount, nil}, nil
}
@@ -179,7 +207,7 @@ func (s *MssqlStmt) NumInput() int {
return s.paramCount
}
func (s *MssqlStmt) sendQuery(args []driver.Value) (err error) {
func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
headers := []headerStruct{
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{s.c.sess.tranid, 1}.pack()},
@@ -190,9 +218,7 @@ func (s *MssqlStmt) sendQuery(args []driver.Value) (err error) {
data: queryNotifHdr{s.notifSub.msgText, s.notifSub.options, s.notifSub.timeout}.pack()})
}
if len(args) != s.paramCount {
return errors.New(fmt.Sprintf("sql: expected %d parameters, got %d", s.paramCount, len(args)))
}
// no need to check number of parameters here, it is checked by database/sql
if s.c.sess.logFlags&logSQL != 0 {
s.c.sess.log.Println(s.query)
}
@@ -204,48 +230,71 @@ func (s *MssqlStmt) sendQuery(args []driver.Value) (err error) {
}
if len(args) == 0 {
if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil {
if s.c.sess.tranid != 0 {
return err
if s.c.sess.logFlags&logErrors != 0 {
s.c.sess.log.Printf("Failed to send SqlBatch with %v", err)
}
return CheckBadConn(err)
return driver.ErrBadConn
}
} else {
params := make([]Param, len(args)+2)
decls := make([]string, len(args))
params[0], err = s.makeParam(s.query)
if err != nil {
return
}
params[0] = makeStrParam(s.query)
for i, val := range args {
params[i+2], err = s.makeParam(val)
params[i+2], err = s.makeParam(val.Value)
if err != nil {
return
}
name := fmt.Sprintf("@p%d", i+1)
var name string
if len(val.Name) > 0 {
name = "@" + val.Name
} else {
name = fmt.Sprintf("@p%d", val.Ordinal)
}
params[i+2].Name = name
decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+2].ti))
}
params[1], err = s.makeParam(strings.Join(decls, ","))
if err != nil {
return
}
params[1] = makeStrParam(strings.Join(decls, ","))
if err = sendRpc(s.c.sess.buf, headers, Sp_ExecuteSql, 0, params); err != nil {
if s.c.sess.tranid != 0 {
return err
if s.c.sess.logFlags&logErrors != 0 {
s.c.sess.log.Printf("Failed to send Rpc with %v", err)
}
return CheckBadConn(err)
return driver.ErrBadConn
}
}
return
}
func (s *MssqlStmt) Query(args []driver.Value) (res driver.Rows, err error) {
res = nil
if err = s.sendQuery(args); err != nil {
return
type namedValue struct {
Name string
Ordinal int
Value driver.Value
}
func convertOldArgs(args []driver.Value) []namedValue {
list := make([]namedValue, len(args))
for i, v := range args {
list[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
return list
}
func (s *MssqlStmt) Query(args []driver.Value) (driver.Rows, error) {
return s.queryContext(context.Background(), convertOldArgs(args))
}
func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (driver.Rows, error) {
if err := s.sendQuery(args); err != nil {
return nil, err
}
return s.processQueryResponse(ctx)
}
func (s *MssqlStmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
tokchan := make(chan tokenStruct, 5)
go processResponse(s.c.sess, tokchan)
go processResponse(ctx, s.c.sess, tokchan)
// process metadata
var cols []columnStruct
loop:
@@ -261,23 +310,32 @@ loop:
case []columnStruct:
cols = token
break loop
case error:
if s.c.sess.tranid != 0 {
return nil, token
case doneStruct:
if token.isError() {
return nil, token.getError()
}
return nil, CheckBadConn(token)
case error:
return nil, token
}
}
res = &MssqlRows{sess: s.c.sess, tokchan: tokchan, cols: cols}
return
}
func (s *MssqlStmt) Exec(args []driver.Value) (res driver.Result, err error) {
if err = s.sendQuery(args); err != nil {
return
func (s *MssqlStmt) Exec(args []driver.Value) (driver.Result, error) {
return s.exec(context.Background(), convertOldArgs(args))
}
func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
if err := s.sendQuery(args); err != nil {
return nil, err
}
return s.processExec(ctx)
}
func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err error) {
tokchan := make(chan tokenStruct, 5)
go processResponse(s.c.sess, tokchan)
go processResponse(ctx, s.c.sess, tokchan)
var rowCount int64
for token := range tokchan {
switch token := token.(type) {
@@ -289,14 +347,11 @@ func (s *MssqlStmt) Exec(args []driver.Value) (res driver.Result, err error) {
if token.Status&doneCount != 0 {
rowCount = int64(token.RowCount)
}
if token.isError() {
return nil, token.getError()
}
case error:
if s.c.sess.logFlags&logErrors != 0 {
s.c.sess.log.Println("got error:", token)
}
if s.c.sess.tranid != 0 {
return nil, token
}
return nil, CheckBadConn(token)
return nil, token
}
}
return &MssqlResult{s.c, rowCount}, nil
@@ -339,6 +394,10 @@ func (rc *MssqlRows) Next(dest []driver.Value) (error) {
dest[i] = tokdata[i]
}
return nil
case doneStruct:
if tokdata.isError() {
return tokdata.getError()
}
case error:
return tokdata
}
@@ -410,6 +469,13 @@ func (r *MssqlRows) ColumnTypeNullable(index int) (nullable, ok bool) {
return
}
func makeStrParam(val string) (res Param) {
res.ti.TypeId = typeNVarChar
res.buffer = str2ucs2(val)
res.ti.Size = len(res.buffer)
return
}
func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
if val == nil {
res.ti.TypeId = typeNVarChar
@@ -433,9 +499,7 @@ func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
res.ti.Size = len(val)
res.buffer = val
case string:
res.ti.TypeId = typeNVarChar
res.buffer = str2ucs2(val)
res.ti.Size = len(res.buffer)
res = makeStrParam(val)
case bool:
res.ti.TypeId = typeBitN
res.ti.Size = 1

vendor/github.com/denisenkom/go-mssqldb/mssql_go18.go
View File

@@ -0,0 +1,69 @@
// +build go1.8
package mssql
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
)
var _ driver.Pinger = &MssqlConn{}
func (c *MssqlConn) Ping(ctx context.Context) error {
stmt := &MssqlStmt{c, `select 1;`, 0, nil}
_, err := stmt.ExecContext(ctx, nil)
return err
}
func (c *MssqlConn) BeginContext(ctx context.Context) (driver.Tx, error) {
if driver.ReadOnlyFromContext(ctx) {
return nil, errors.New("Read-only transactions are not supported")
}
tdsIsolation := isolationUseCurrent
isolation, ok := driver.IsolationFromContext(ctx)
if ok {
switch sql.IsolationLevel(isolation) {
case sql.LevelDefault:
tdsIsolation = isolationUseCurrent
case sql.LevelReadUncommitted:
tdsIsolation = isolationReadUncommited
case sql.LevelReadCommitted:
tdsIsolation = isolationReadCommited
case sql.LevelWriteCommitted:
return nil, errors.New("LevelWriteCommitted isolation level is not supported")
case sql.LevelRepeatableRead:
tdsIsolation = isolationRepeatableRead
case sql.LevelSnapshot:
tdsIsolation = isolationSnapshot
case sql.LevelSerializable:
tdsIsolation = isolationSerializable
case sql.LevelLinearizable:
return nil, errors.New("LevelLinearizable isolation level is not supported")
default:
return nil, errors.New("Isolation level is not supported or unknown")
}
}
return c.begin(ctx, tdsIsolation)
}
func (c *MssqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
return c.prepareContext(ctx, query)
}
func (s *MssqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return s.queryContext(ctx, list)
}
func (s *MssqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return s.exec(ctx, list)
}
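A minimal go1.8 usage sketch of the context hooks added in this file; the DSN values are placeholders and the import path is the vendored one shown above:
```go
// +build go1.8

package main

import (
	"context"
	"database/sql"
	"log"
	"time"

	_ "github.com/denisenkom/go-mssqldb" // registers the "mssql" driver
)

func main() {
	// DSN values are placeholders.
	db, err := sql.Open("mssql", "server=localhost;user id=sa;password=secret")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Cancelling or timing out the context sends a TDS attention signal and
	// the query returns the context's error.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	var one int
	if err := db.QueryRowContext(ctx, "select 1").Scan(&one); err != nil {
		log.Fatal(err)
	}
	log.Println("result:", one)
}
```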

View File

@@ -33,7 +33,7 @@ func (c *timeoutConn) Read(b []byte) (n int, err error) {
c.continueRead = false
}
if !c.continueRead {
var packet uint8
var packet packetType
packet, err = c.buf.BeginRead()
if err != nil {
err = fmt.Errorf("Cannot read handshake packet: %s", err.Error())

View File

@@ -11,6 +11,9 @@ type parser struct {
w bytes.Buffer
paramCount int
paramMax int
// using map as a set
namedParams map [string]bool
}
func (p *parser) next() (rune, bool) {
@@ -40,12 +43,13 @@ type stateFunc func(*parser) stateFunc
func parseParams(query string) (string, int) {
p := &parser{
r: bytes.NewReader([]byte(query)),
namedParams: map [string]bool{},
}
state := parseNormal
for state != nil {
state = state(p)
}
return p.w.String(), p.paramMax
return p.w.String(), p.paramMax + len(p.namedParams)
}
func parseNormal(p *parser) stateFunc {
@@ -55,7 +59,7 @@ func parseNormal(p *parser) stateFunc {
return nil
}
if ch == '?' {
return parseParameter
return parseOrdinalParameter
} else if ch == '$' || ch == ':' {
ch2, ok := p.next()
if !ok {
@@ -64,7 +68,9 @@ func parseNormal(p *parser) stateFunc {
}
p.unread()
if ch2 >= '0' && ch2 <= '9' {
return parseParameter
return parseOrdinalParameter
} else if 'a' <= ch2 && ch2 <= 'z' || 'A' <= ch2 && ch2 <= 'Z' {
return parseNamedParameter
}
}
p.write(ch)
@@ -83,7 +89,7 @@ func parseNormal(p *parser) stateFunc {
}
}
func parseParameter(p *parser) stateFunc {
func parseOrdinalParameter(p *parser) stateFunc {
var paramN int
var ok bool
for {
@@ -113,6 +119,30 @@ func parseParameter(p *parser) stateFunc {
return parseNormal
}
func parseNamedParameter(p *parser) stateFunc {
var paramName string
var ok bool
for {
var ch rune
ch, ok = p.next()
if ok && (ch >= '0' && ch <= '9' || 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z') {
paramName = paramName + string(ch)
} else {
break
}
}
if ok {
p.unread()
}
p.namedParams[paramName] = true
p.w.WriteString("@")
p.w.WriteString(paramName)
if !ok {
return nil
}
return parseNormal
}
func parseQuote(p *parser) stateFunc {
for {
ch, ok := p.next()
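A short sketch of the new named-parameter support seen from the go1.8 API; the query and DSN are placeholders, and the ":id" placeholder is rewritten to "@id" by the parser above:
```go
// +build go1.8

package main

import (
	"context"
	"database/sql"
	"log"

	_ "github.com/denisenkom/go-mssqldb" // registers the "mssql" driver
)

func main() {
	// DSN values are placeholders.
	db, err := sql.Open("mssql", "server=localhost;user id=sa;password=secret")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// ":id" (or "$id") in the query text is rewritten to "@id" by the parser
	// and bound through database/sql named arguments.
	row := db.QueryRowContext(context.Background(),
		"select name from sys.databases where database_id = :id",
		sql.Named("id", 1))

	var name string
	if err := row.Scan(&name); err != nil {
		log.Fatal(err)
	}
	log.Println("database:", name)
}
```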

View File

@@ -16,6 +16,7 @@ import (
"time"
"unicode/utf16"
"unicode/utf8"
"golang.org/x/net/context" // use the "x/net/context" for backwards compatibility.
)
func parseInstances(msg []byte) map[string]map[string]string {
@@ -79,11 +80,16 @@ const (
)
// packet types
// https://msdn.microsoft.com/en-us/library/dd304214.aspx
const (
packSQLBatch = 1
packSQLBatch packetType = 1
packRPCRequest = 3
packReply = 4
packCancel = 6
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
packAttention = 6
packBulkLoadBCP = 7
packTransMgrReq = 14
packNormal = 15
@@ -119,7 +125,7 @@ type tdsSession struct {
columns []columnStruct
tranid uint64
logFlags uint64
log *Logger
log optionalLogger
routedServer string
routedPort uint16
}
@@ -131,6 +137,7 @@ const (
logSQL = 8
logParams = 16
logTransaction = 32
logDebug = 64
)
type columnStruct struct {
@@ -632,6 +639,13 @@ func sendSqlBatch72(buf *tdsBuffer,
return buf.FinishPacket()
}
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
func sendAttention(buf *tdsBuffer) (error) {
buf.BeginPacket(packAttention)
return buf.FinishPacket()
}
type connectParams struct {
logFlags uint64
port uint64
@@ -712,7 +726,8 @@ func parseConnectParams(dsn string) (connectParams, error) {
}
}
p.dial_timeout = 5 * time.Second
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
p.dial_timeout = 15 * time.Second
p.conn_timeout = 30 * time.Second
strconntimeout, ok := params["connection timeout"]
if ok {
@@ -732,6 +747,11 @@ func parseConnectParams(dsn string) (connectParams, error) {
}
p.dial_timeout = time.Duration(timeout) * time.Second
}
// default keep alive should be 30 seconds according to spec:
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
p.keepAlive = 30 * time.Second
keepAlive, ok := params["keepalive"]
if ok {
timeout, err := strconv.ParseUint(keepAlive, 0, 16)
@@ -1028,7 +1048,7 @@ initiate_connection:
var sspi_msg []byte
continue_login:
tokchan := make(chan tokenStruct, 5)
go processResponse(&sess, tokchan)
go processResponse(context.Background(), &sess, tokchan)
success := false
for tok := range tokchan {
switch token := tok.(type) {

View File

@@ -5,6 +5,9 @@ import (
"io"
"strconv"
"strings"
"golang.org/x/net/context"
"net"
"errors"
)
// token ids
@@ -25,6 +28,7 @@ const (
)
// done flags
// https://msdn.microsoft.com/en-us/library/dd340421.aspx
const (
doneFinal = 0
doneMore = 1
@@ -77,6 +81,19 @@ type doneStruct struct {
Status uint16
CurCmd uint16
RowCount uint64
errors []Error
}
func (d doneStruct) isError() bool {
return d.Status&doneError != 0 || len(d.errors) > 0
}
func (d doneStruct) getError() Error {
if len(d.errors) > 0 {
return d.errors[len(d.errors)-1]
} else {
return Error{Message: "Request failed but didn't provide reason"}
}
}
type doneInProcStruct doneStruct
@@ -365,6 +382,7 @@ func parseOrder(r *tdsBuffer) (res orderStruct) {
return res
}
// https://msdn.microsoft.com/en-us/library/dd340421.aspx
func parseDone(r *tdsBuffer) (res doneStruct) {
res.Status = r.uint16()
res.CurCmd = r.uint16()
@@ -372,6 +390,7 @@ func parseDone(r *tdsBuffer) (res doneStruct) {
return res
}
// https://msdn.microsoft.com/en-us/library/dd340553.aspx
func parseDoneInProc(r *tdsBuffer) (res doneInProcStruct) {
res.Status = r.uint16()
res.CurCmd = r.uint16()
@@ -480,15 +499,22 @@ func parseInfo(r *tdsBuffer) (res Error) {
return
}
func processResponse(sess *tdsSession, ch chan tokenStruct) {
func processSingleResponse(sess *tdsSession, ch chan tokenStruct) {
defer func() {
if err := recover(); err != nil {
if sess.logFlags&logErrors != 0 {
sess.log.Printf("ERROR: Intercepted panick %v", err)
}
ch <- err
}
close(ch)
}()
packet_type, err := sess.buf.BeginRead()
if err != nil {
if sess.logFlags&logErrors != 0 {
sess.log.Printf("ERROR: BeginRead failed %v", err)
}
ch <- err
return
}
@@ -496,10 +522,12 @@ func processResponse(sess *tdsSession, ch chan tokenStruct) {
badStreamPanicf("invalid response packet type, expected REPLY, actual: %d", packet_type)
}
var columns []columnStruct
var lastError Error
var failed bool
errs := make([]Error, 0, 5)
for {
token := sess.buf.byte()
if sess.logFlags&logDebug != 0 {
sess.log.Printf("got token id %d", token)
}
switch token {
case tokenSSPI:
ch <- parseSSPIMsg(sess.buf)
@@ -521,18 +549,17 @@ func processResponse(sess *tdsSession, ch chan tokenStruct) {
ch <- done
case tokenDone, tokenDoneProc:
done := parseDone(sess.buf)
if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 {
sess.log.Printf("(%d row(s) affected)\n", done.RowCount)
}
if done.Status&doneError != 0 || failed {
ch <- lastError
return
done.errors = errs
if sess.logFlags&logDebug != 0 {
sess.log.Printf("got DONE or DONEPROC status=%d", done.Status)
}
if done.Status&doneSrvError != 0 {
lastError.Message = "Server Error"
ch <- lastError
ch <- errors.New("SQL Server had internal error")
return
}
if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 {
sess.log.Printf("(%d row(s) affected)\n", done.RowCount)
}
ch <- done
if done.Status&doneMore == 0 {
return
@@ -551,13 +578,19 @@ func processResponse(sess *tdsSession, ch chan tokenStruct) {
case tokenEnvChange:
processEnvChg(sess)
case tokenError:
lastError = parseError72(sess.buf)
failed = true
err := parseError72(sess.buf)
if sess.logFlags&logDebug != 0 {
sess.log.Printf("got ERROR %d %s", err.Number, err.Message)
}
errs = append(errs, err)
if sess.logFlags&logErrors != 0 {
sess.log.Println(lastError.Message)
sess.log.Println(err.Message)
}
case tokenInfo:
info := parseInfo(sess.buf)
if sess.logFlags&logDebug != 0 {
sess.log.Printf("got INFO %d %s", info.Number, info.Message)
}
if sess.logFlags&logMessages != 0 {
sess.log.Println(info.Message)
}
@@ -566,3 +599,95 @@ func processResponse(sess *tdsSession, ch chan tokenStruct) {
}
}
}
func processResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct) {
defer func() {
close(ch)
}()
doneChan := ctx.Done()
cancelInProgress := false
cancelledByContext := false
var cancelError error
// loop over multiple responses
for {
if sess.logFlags&logDebug != 0 {
sess.log.Println("initiating resonse reading")
}
tokChan := make(chan tokenStruct)
go processSingleResponse(sess, tokChan)
// loop over multiple tokens in response
tokensLoop:
for {
select {
case tok, ok := <-tokChan:
if ok {
if cancelInProgress {
switch tok := tok.(type) {
case doneStruct:
if tok.Status&doneAttn != 0 {
if sess.logFlags&logDebug != 0 {
sess.log.Println("got cancellation confirmation from server")
}
if cancelledByContext {
ch <- ctx.Err()
} else {
ch <- cancelError
}
return
}
}
} else {
if err, ok := tok.(net.Error); ok && err.Timeout() {
cancelError = err
if sess.logFlags&logDebug != 0 {
sess.log.Println("got timeout error, sending attention signal to server")
}
err := sendAttention(sess.buf)
if err != nil {
if sess.logFlags&logErrors != 0 {
sess.log.Println("Failed to send attention signal %v", err)
}
ch <- err
return
}
doneChan = nil
cancelInProgress = true
cancelledByContext = false
} else {
ch <- tok
}
}
} else {
// response finished
if cancelInProgress {
if sess.logFlags&logDebug != 0 {
sess.log.Println("response finished but waiting for attention ack")
}
break tokensLoop
} else {
if sess.logFlags&logDebug != 0 {
sess.log.Println("response finished")
}
return
}
}
case <-doneChan:
if sess.logFlags&logDebug != 0 {
sess.log.Println("got cancel message, sending attention signal to server")
}
err := sendAttention(sess.buf)
if err != nil {
if sess.logFlags&logErrors != 0 {
sess.log.Println("Failed to send attention signal %v", err)
}
ch <- err
return
}
doneChan = nil
cancelInProgress = true
cancelledByContext = true
}
}
}
}

View File

@@ -16,7 +16,18 @@ const (
tmSaveXact = 9
)
func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation uint8,
type isoLevel uint8
const (
isolationUseCurrent isoLevel = 0
isolationReadUncommited = 1
isolationReadCommited = 2
isolationRepeatableRead = 3
isolationSerializable = 4
isolationSnapshot = 5
)
func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel,
name string) (err error) {
buf.BeginPacket(packTransMgrReq)
writeAllHeaders(buf, headers)

View File

@@ -336,12 +336,14 @@ var oidEncryptionAlgorithmDESCBC = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 7}
var oidEncryptionAlgorithmDESEDE3CBC = asn1.ObjectIdentifier{1, 2, 840, 113549, 3, 7}
var oidEncryptionAlgorithmAES256CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 42}
var oidEncryptionAlgorithmAES128GCM = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 6}
var oidEncryptionAlgorithmAES128CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 2}
func (eci encryptedContentInfo) decrypt(key []byte) ([]byte, error) {
alg := eci.ContentEncryptionAlgorithm.Algorithm
if !alg.Equal(oidEncryptionAlgorithmDESCBC) &&
!alg.Equal(oidEncryptionAlgorithmDESEDE3CBC) &&
!alg.Equal(oidEncryptionAlgorithmAES256CBC) &&
!alg.Equal(oidEncryptionAlgorithmAES128CBC) &&
!alg.Equal(oidEncryptionAlgorithmAES128GCM) {
fmt.Printf("Unsupported Content Encryption Algorithm: %s\n", alg)
return nil, ErrUnsupportedAlgorithm
@@ -378,7 +380,7 @@ func (eci encryptedContentInfo) decrypt(key []byte) ([]byte, error) {
block, err = des.NewTripleDESCipher(key)
case alg.Equal(oidEncryptionAlgorithmAES256CBC):
fallthrough
case alg.Equal(oidEncryptionAlgorithmAES128GCM):
case alg.Equal(oidEncryptionAlgorithmAES128GCM), alg.Equal(oidEncryptionAlgorithmAES128CBC):
block, err = aes.NewCipher(key)
}

View File

@@ -45,7 +45,11 @@ func indirect(v reflect.Value, decodingNull bool) (json.Unmarshaler, encoding.Te
break
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
if v.CanSet() {
v.Set(reflect.New(v.Type().Elem()))
} else {
v = reflect.New(v.Type().Elem())
}
}
if v.Type().NumMethod() > 0 {
if u, ok := v.Interface().(json.Unmarshaler); ok {

View File

@@ -180,6 +180,12 @@ func (err *Error) Stack() []byte {
return buf.Bytes()
}
// Callers satisfies the bugsnag ErrorWithCallerS() interface
// so that the stack can be read out.
func (err *Error) Callers() []uintptr {
return err.stack
}
// ErrorStack returns a string that contains both the
// error message and the callstack.
func (err *Error) ErrorStack() string {

vendor/github.com/go-ldap/ldap/dn.go
View File

@@ -175,3 +175,70 @@ func ParseDN(str string) (*DN, error) {
}
return dn, nil
}
// Equal returns true if the DNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
// Returns true if they have the same number of relative distinguished names
// and corresponding relative distinguished names (by position) are the same.
func (d *DN) Equal(other *DN) bool {
if len(d.RDNs) != len(other.RDNs) {
return false
}
for i := range d.RDNs {
if !d.RDNs[i].Equal(other.RDNs[i]) {
return false
}
}
return true
}
// AncestorOf returns true if the other DN consists of at least one RDN followed by all the RDNs of the current DN.
// "ou=widgets,o=acme.com" is an ancestor of "ou=sprockets,ou=widgets,o=acme.com"
// "ou=widgets,o=acme.com" is not an ancestor of "ou=sprockets,ou=widgets,o=foo.com"
// "ou=widgets,o=acme.com" is not an ancestor of "ou=widgets,o=acme.com"
func (d *DN) AncestorOf(other *DN) bool {
if len(d.RDNs) >= len(other.RDNs) {
return false
}
// Take the last `len(d.RDNs)` RDNs from the other DN to compare against
otherRDNs := other.RDNs[len(other.RDNs)-len(d.RDNs):]
for i := range d.RDNs {
if !d.RDNs[i].Equal(otherRDNs[i]) {
return false
}
}
return true
}
// Equal returns true if the RelativeDNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
// Relative distinguished names are the same if and only if they have the same number of AttributeTypeAndValues
// and each attribute of the first RDN is the same as the attribute of the second RDN with the same attribute type.
// The order of attributes is not significant.
// Case of attribute types is not significant.
func (r *RelativeDN) Equal(other *RelativeDN) bool {
if len(r.Attributes) != len(other.Attributes) {
return false
}
return r.hasAllAttributes(other.Attributes) && other.hasAllAttributes(r.Attributes)
}
func (r *RelativeDN) hasAllAttributes(attrs []*AttributeTypeAndValue) bool {
for _, attr := range attrs {
found := false
for _, myattr := range r.Attributes {
if myattr.Equal(attr) {
found = true
break
}
}
if !found {
return false
}
}
return true
}
// Equal returns true if the AttributeTypeAndValue is equivalent to the specified AttributeTypeAndValue
// Case of the attribute type is not significant
func (a *AttributeTypeAndValue) Equal(other *AttributeTypeAndValue) bool {
return strings.EqualFold(a.Type, other.Type) && a.Value == other.Value
}
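A minimal usage sketch of the new DN comparison helpers; the import path is the vendored one shown above:
```go
package main

import (
	"fmt"
	"log"

	"github.com/go-ldap/ldap" // vendored path shown above
)

func main() {
	parent, err := ldap.ParseDN("ou=widgets,o=acme.com")
	if err != nil {
		log.Fatal(err)
	}
	child, err := ldap.ParseDN("ou=sprockets,ou=widgets,o=acme.com")
	if err != nil {
		log.Fatal(err)
	}

	// Attribute type case is ignored, values are compared exactly, and a DN
	// is never an ancestor of itself.
	fmt.Println(parent.AncestorOf(child)) // true
	fmt.Println(parent.Equal(child))      // false
	fmt.Println(parent.Equal(parent))     // true
}
```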

View File

@@ -103,6 +103,10 @@ func (s *ActivityService) GetRepositorySubscription(owner, repo string) (*Subscr
// SetRepositorySubscription sets the subscription for the specified repository
// for the authenticated user.
//
// To watch a repository, set subscription.Subscribed to true.
// To ignore notifications made within a repository, set subscription.Ignored to true.
// To stop watching a repository, use DeleteRepositorySubscription.
//
// GitHub API Docs: https://developer.github.com/v3/activity/watching/#set-a-repository-subscription
func (s *ActivityService) SetRepositorySubscription(owner, repo string, subscription *Subscription) (*Subscription, *Response, error) {
u := fmt.Sprintf("repos/%s/%s/subscription", owner, repo)
@@ -124,6 +128,9 @@ func (s *ActivityService) SetRepositorySubscription(owner, repo string, subscrip
// DeleteRepositorySubscription deletes the subscription for the specified
// repository for the authenticated user.
//
// This is used to stop watching a repository. To control whether or not to
// receive notifications from a repository, use SetRepositorySubscription.
//
// GitHub API Docs: https://developer.github.com/v3/activity/watching/#delete-a-repository-subscription
func (s *ActivityService) DeleteRepositorySubscription(owner, repo string) (*Response, error) {
u := fmt.Sprintf("repos/%s/%s/subscription", owner, repo)

View File

@@ -86,6 +86,21 @@ To detect an API rate limit error, you can check if its type is *github.RateLimi
Learn more about GitHub rate limiting at
http://developer.github.com/v3/#rate-limiting.
Accepted Status
Some endpoints may return a 202 Accepted status code, meaning that the
information required is not yet ready and was scheduled to be gathered on
the GitHub side. Methods known to behave like this are documented specifying
this behavior.
To detect this condition of error, you can check if its type is
*github.AcceptedError:
stats, _, err := client.Repositories.ListContributorsStats(org, repo)
if _, ok := err.(*github.AcceptedError); ok {
log.Println("scheduled on GitHub side")
}
Conditional Requests
The GitHub API has good support for conditional requests which will help

View File

@@ -501,6 +501,18 @@ func (r *RateLimitError) Error() string {
r.Response.StatusCode, r.Message, r.Rate.Reset.Time.Sub(time.Now()))
}
// AcceptedError occurs when GitHub returns 202 Accepted response with an
// empty body, which means a job was scheduled on the GitHub side to process
// the information needed and cache it.
// Technically, 202 Accepted is not a real error, it's just used to
// indicate that results are not ready yet, but should be available soon.
// The request can be repeated after some time.
type AcceptedError struct{}
func (*AcceptedError) Error() string {
return "job scheduled on GitHub side; try again later"
}
// AbuseRateLimitError occurs when GitHub returns 403 Forbidden response with the
// "documentation_url" field value equal to "https://developer.github.com/v3#abuse-rate-limits".
type AbuseRateLimitError struct {
@@ -564,14 +576,19 @@ func (e *Error) Error() string {
}
// CheckResponse checks the API response for errors, and returns them if
// present. A response is considered an error if it has a status code outside
// the 200 range. API error responses are expected to have either no response
// present. A response is considered an error if it has a status code outside
// the 200 range or equal to 202 Accepted.
// API error responses are expected to have either no response
// body, or a JSON response body that maps to ErrorResponse. Any other
// response body will be silently ignored.
//
// The error type will be *RateLimitError for rate limit exceeded errors,
// *AcceptedError for 202 Accepted status codes,
// and *TwoFactorAuthError for two-factor authentication errors.
func CheckResponse(r *http.Response) error {
if r.StatusCode == http.StatusAccepted {
return &AcceptedError{}
}
if c := r.StatusCode; 200 <= c && c <= 299 {
return nil
}

View File

@@ -36,7 +36,7 @@ func (p Project) String() string {
//
// GitHub API docs: https://developer.github.com/v3/projects/#get-a-project
func (s *ProjectsService) GetProject(id int) (*Project, *Response, error) {
u := fmt.Sprintf("/projects/%v", id)
u := fmt.Sprintf("projects/%v", id)
req, err := s.client.NewRequest("GET", u, nil)
if err != nil {
return nil, nil, err
@@ -68,7 +68,7 @@ type ProjectOptions struct {
//
// GitHub API docs: https://developer.github.com/v3/projects/#update-a-project
func (s *ProjectsService) UpdateProject(id int, opt *ProjectOptions) (*Project, *Response, error) {
u := fmt.Sprintf("/projects/%v", id)
u := fmt.Sprintf("projects/%v", id)
req, err := s.client.NewRequest("PATCH", u, opt)
if err != nil {
return nil, nil, err
@@ -90,7 +90,7 @@ func (s *ProjectsService) UpdateProject(id int, opt *ProjectOptions) (*Project,
//
// GitHub API docs: https://developer.github.com/v3/projects/#delete-a-project
func (s *ProjectsService) DeleteProject(id int) (*Response, error) {
u := fmt.Sprintf("/projects/%v", id)
u := fmt.Sprintf("projects/%v", id)
req, err := s.client.NewRequest("DELETE", u, nil)
if err != nil {
return nil, err
@@ -117,7 +117,7 @@ type ProjectColumn struct {
//
// GitHub API docs: https://developer.github.com/v3/projects/columns/#list-project-columns
func (s *ProjectsService) ListProjectColumns(projectId int, opt *ListOptions) ([]*ProjectColumn, *Response, error) {
u := fmt.Sprintf("/projects/%v/columns", projectId)
u := fmt.Sprintf("projects/%v/columns", projectId)
u, err := addOptions(u, opt)
if err != nil {
return nil, nil, err
@@ -144,7 +144,7 @@ func (s *ProjectsService) ListProjectColumns(projectId int, opt *ListOptions) ([
//
// GitHub API docs: https://developer.github.com/v3/projects/columns/#get-a-project-column
func (s *ProjectsService) GetProjectColumn(id int) (*ProjectColumn, *Response, error) {
u := fmt.Sprintf("/projects/columns/%v", id)
u := fmt.Sprintf("projects/columns/%v", id)
req, err := s.client.NewRequest("GET", u, nil)
if err != nil {
return nil, nil, err
@@ -174,7 +174,7 @@ type ProjectColumnOptions struct {
//
// GitHub API docs: https://developer.github.com/v3/projects/columns/#create-a-project-column
func (s *ProjectsService) CreateProjectColumn(projectId int, opt *ProjectColumnOptions) (*ProjectColumn, *Response, error) {
u := fmt.Sprintf("/projects/%v/columns", projectId)
u := fmt.Sprintf("projects/%v/columns", projectId)
req, err := s.client.NewRequest("POST", u, opt)
if err != nil {
return nil, nil, err
@@ -196,7 +196,7 @@ func (s *ProjectsService) CreateProjectColumn(projectId int, opt *ProjectColumnO
//
// GitHub API docs: https://developer.github.com/v3/projects/columns/#update-a-project-column
func (s *ProjectsService) UpdateProjectColumn(columnID int, opt *ProjectColumnOptions) (*ProjectColumn, *Response, error) {
u := fmt.Sprintf("/projects/columns/%v", columnID)
u := fmt.Sprintf("projects/columns/%v", columnID)
req, err := s.client.NewRequest("PATCH", u, opt)
if err != nil {
return nil, nil, err
@@ -218,7 +218,7 @@ func (s *ProjectsService) UpdateProjectColumn(columnID int, opt *ProjectColumnOp
//
// GitHub API docs: https://developer.github.com/v3/projects/columns/#delete-a-project-column
func (s *ProjectsService) DeleteProjectColumn(columnID int) (*Response, error) {
u := fmt.Sprintf("/projects/columns/%v", columnID)
u := fmt.Sprintf("projects/columns/%v", columnID)
req, err := s.client.NewRequest("DELETE", u, nil)
if err != nil {
return nil, err
@@ -242,7 +242,7 @@ type ProjectColumnMoveOptions struct {
//
// GitHub API docs: https://developer.github.com/v3/projects/columns/#move-a-project-column
func (s *ProjectsService) MoveProjectColumn(columnID int, opt *ProjectColumnMoveOptions) (*Response, error) {
u := fmt.Sprintf("/projects/columns/%v/moves", columnID)
u := fmt.Sprintf("projects/columns/%v/moves", columnID)
req, err := s.client.NewRequest("POST", u, opt)
if err != nil {
return nil, err
@@ -270,7 +270,7 @@ type ProjectCard struct {
//
// GitHub API docs: https://developer.github.com/v3/projects/cards/#list-project-cards
func (s *ProjectsService) ListProjectCards(columnID int, opt *ListOptions) ([]*ProjectCard, *Response, error) {
u := fmt.Sprintf("/projects/columns/%v/cards", columnID)
u := fmt.Sprintf("projects/columns/%v/cards", columnID)
u, err := addOptions(u, opt)
if err != nil {
return nil, nil, err
@@ -297,7 +297,7 @@ func (s *ProjectsService) ListProjectCards(columnID int, opt *ListOptions) ([]*P
//
// GitHub API docs: https://developer.github.com/v3/projects/cards/#get-a-project-card
func (s *ProjectsService) GetProjectCard(columnID int) (*ProjectCard, *Response, error) {
u := fmt.Sprintf("/projects/columns/cards/%v", columnID)
u := fmt.Sprintf("projects/columns/cards/%v", columnID)
req, err := s.client.NewRequest("GET", u, nil)
if err != nil {
return nil, nil, err
@@ -332,7 +332,7 @@ type ProjectCardOptions struct {
//
// GitHub API docs: https://developer.github.com/v3/projects/cards/#create-a-project-card
func (s *ProjectsService) CreateProjectCard(columnID int, opt *ProjectCardOptions) (*ProjectCard, *Response, error) {
u := fmt.Sprintf("/projects/columns/%v/cards", columnID)
u := fmt.Sprintf("projects/columns/%v/cards", columnID)
req, err := s.client.NewRequest("POST", u, opt)
if err != nil {
return nil, nil, err
@@ -354,7 +354,7 @@ func (s *ProjectsService) CreateProjectCard(columnID int, opt *ProjectCardOption
//
// GitHub API docs: https://developer.github.com/v3/projects/cards/#update-a-project-card
func (s *ProjectsService) UpdateProjectCard(cardID int, opt *ProjectCardOptions) (*ProjectCard, *Response, error) {
u := fmt.Sprintf("/projects/columns/cards/%v", cardID)
u := fmt.Sprintf("projects/columns/cards/%v", cardID)
req, err := s.client.NewRequest("PATCH", u, opt)
if err != nil {
return nil, nil, err
@@ -376,7 +376,7 @@ func (s *ProjectsService) UpdateProjectCard(cardID int, opt *ProjectCardOptions)
//
// GitHub API docs: https://developer.github.com/v3/projects/cards/#delete-a-project-card
func (s *ProjectsService) DeleteProjectCard(cardID int) (*Response, error) {
u := fmt.Sprintf("/projects/columns/cards/%v", cardID)
u := fmt.Sprintf("projects/columns/cards/%v", cardID)
req, err := s.client.NewRequest("DELETE", u, nil)
if err != nil {
return nil, err
@@ -404,7 +404,7 @@ type ProjectCardMoveOptions struct {
//
// GitHub API docs: https://developer.github.com/v3/projects/cards/#move-a-project-card
func (s *ProjectsService) MoveProjectCard(cardID int, opt *ProjectCardMoveOptions) (*Response, error) {
u := fmt.Sprintf("/projects/columns/cards/%v/moves", cardID)
u := fmt.Sprintf("projects/columns/cards/%v/moves", cardID)
req, err := s.client.NewRequest("POST", u, opt)
if err != nil {
return nil, err
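Dropping the leading slash matters because these request paths are resolved against the client's BaseURL; a rooted path discards the base's own path component (for example a GitHub Enterprise "/api/v3/" prefix). A standalone net/url sketch of the difference (hostnames are placeholders):

	base, _ := url.Parse("https://ghe.example.com/api/v3/")

	abs, _ := url.Parse("/projects/1") // rooted: replaces the base path
	fmt.Println(base.ResolveReference(abs)) // https://ghe.example.com/projects/1

	rel, _ := url.Parse("projects/1") // relative: keeps the /api/v3/ prefix
	fmt.Println(base.ResolveReference(rel)) // https://ghe.example.com/api/v3/projects/1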

View File

@@ -501,9 +501,9 @@ func (s *RepositoriesService) ListTags(owner string, repo string, opt *ListOptio
// Branch represents a repository branch
type Branch struct {
Name *string `json:"name,omitempty"`
Commit *Commit `json:"commit,omitempty"`
Protected *bool `json:"protected,omitempty"`
Name *string `json:"name,omitempty"`
Commit *RepositoryCommit `json:"commit,omitempty"`
Protected *bool `json:"protected,omitempty"`
}
// Protection represents a repository branch's protection.

View File

@@ -20,7 +20,6 @@ type RepositoryCommit struct {
Author *User `json:"author,omitempty"`
Committer *User `json:"committer,omitempty"`
Parents []Commit `json:"parents,omitempty"`
Message *string `json:"message,omitempty"`
HTMLURL *string `json:"html_url,omitempty"`
URL *string `json:"url,omitempty"`
CommentsURL *string `json:"comments_url,omitempty"`
@@ -48,13 +47,16 @@ func (c CommitStats) String() string {
// CommitFile represents a file modified in a commit.
type CommitFile struct {
SHA *string `json:"sha,omitempty"`
Filename *string `json:"filename,omitempty"`
Additions *int `json:"additions,omitempty"`
Deletions *int `json:"deletions,omitempty"`
Changes *int `json:"changes,omitempty"`
Status *string `json:"status,omitempty"`
Patch *string `json:"patch,omitempty"`
SHA *string `json:"sha,omitempty"`
Filename *string `json:"filename,omitempty"`
Additions *int `json:"additions,omitempty"`
Deletions *int `json:"deletions,omitempty"`
Changes *int `json:"changes,omitempty"`
Status *string `json:"status,omitempty"`
Patch *string `json:"patch,omitempty"`
BlobURL *string `json:"blob_url,omitempty"`
RawURL *string `json:"raw_url,omitempty"`
ContentsURL *string `json:"contents_url,omitempty"`
}
func (c CommitFile) String() string {

View File

@@ -267,8 +267,12 @@ func (s *RepositoriesService) GetArchiveLink(owner, repo string, archiveformat a
} else {
resp, err = s.client.client.Transport.RoundTrip(req)
}
if err != nil || resp.StatusCode != http.StatusFound {
return nil, newResponse(resp), err
if err != nil {
return nil, nil, err
}
resp.Body.Close()
if resp.StatusCode != http.StatusFound {
return nil, newResponse(resp), fmt.Errorf("unexpected status code: %s", resp.Status)
}
parsedURL, err := url.Parse(resp.Header.Get("Location"))
return parsedURL, newResponse(resp), err

View File

@@ -82,6 +82,27 @@ func (s *RepositoriesService) ListDeployments(owner, repo string, opt *Deploymen
return *deployments, resp, err
}
// GetDeployment returns a single deployment of a repository.
//
// GitHub API docs: https://developer.github.com/v3/repos/deployments/
// Note: GetDeployment uses the undocumented GitHub API endpoint /repos/:owner/:repo/deployments/:id.
func (s *RepositoriesService) GetDeployment(owner, repo string, deploymentID int) (*Deployment, *Response, error) {
u := fmt.Sprintf("repos/%v/%v/deployments/%v", owner, repo, deploymentID)
req, err := s.client.NewRequest("GET", u, nil)
if err != nil {
return nil, nil, err
}
deployment := new(Deployment)
resp, err := s.client.Do(req, deployment)
if err != nil {
return nil, resp, err
}
return deployment, resp, err
}
// CreateDeployment creates a new deployment for a repository.
//
// GitHub API docs: https://developer.github.com/v3/repos/deployments/#create-a-deployment
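A small usage sketch for the new GetDeployment method; owner, repo and the deployment ID are placeholders, and the pointer fields may be nil in practice:

	deployment, _, err := client.Repositories.GetDeployment("octocat", "hello-world", 42)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("deployment %d targets ref %s\n", *deployment.ID, *deployment.Ref)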

View File

@@ -50,6 +50,12 @@ type RepositoryCreateForkOptions struct {
// CreateFork creates a fork of the specified repository.
//
// This method might return an *AcceptedError and a status code of
// 202. This is because this is the status that GitHub returns to signify that
// it is now creating the fork in a background task.
// A follow up request, after a delay of a second or so, should result
// in a successful request.
//
// GitHub API docs: https://developer.github.com/v3/repos/forks/#create-a-fork
func (s *RepositoriesService) CreateFork(owner, repo string, opt *RepositoryCreateForkOptions) (*Repository, *Response, error) {
u := fmt.Sprintf("repos/%v/%v/forks", owner, repo)

View File

@@ -30,7 +30,7 @@ type PagesBuild struct {
Commit *string `json:"commit,omitempty"`
Duration *int `json:"duration,omitempty"`
CreatedAt *Timestamp `json:"created_at,omitempty"`
UpdatedAt *Timestamp `json:"created_at,omitempty"`
UpdatedAt *Timestamp `json:"updated_at,omitempty"`
}
// GetPagesInfo fetches information about a GitHub Pages site.

View File

@@ -39,8 +39,8 @@ func (w WeeklyStats) String() string {
// deletions and commit counts.
//
// If this is the first time these statistics are requested for the given
// repository, this method will return a non-nil error and a status code of
// 202. This is because this is the status that github returns to signify that
// repository, this method will return an *AcceptedError and a status code of
// 202. This is because this is the status that GitHub returns to signify that
// it is now computing the requested statistics. A follow up request, after a
// delay of a second or so, should result in a successful request.
//
@@ -78,8 +78,8 @@ func (w WeeklyCommitActivity) String() string {
// starting on Sunday.
//
// If this is the first time these statistics are requested for the given
// repository, this method will return a non-nil error and a status code of
// 202. This is because this is the status that github returns to signify that
// repository, this method will return an *AcceptedError and a status code of
// 202. This is because this is the status that GitHub returns to signify that
// it is now computing the requested statistics. A follow up request, after a
// delay of a second or so, should result in a successful request.
//
@@ -104,6 +104,12 @@ func (s *RepositoriesService) ListCommitActivity(owner, repo string) ([]*WeeklyC
// deletions pushed to a repository. Returned WeeklyStats will contain
// additions and deletions, but not total commits.
//
// If this is the first time these statistics are requested for the given
// repository, this method will return an *AcceptedError and a status code of
// 202. This is because this is the status that GitHub returns to signify that
// it is now computing the requested statistics. A follow up request, after a
// delay of a second or so, should result in a successful request.
//
// GitHub API Docs: https://developer.github.com/v3/repos/statistics/#code-frequency
func (s *RepositoriesService) ListCodeFrequency(owner, repo string) ([]*WeeklyStats, *Response, error) {
u := fmt.Sprintf("repos/%v/%v/stats/code_frequency", owner, repo)
@@ -152,11 +158,10 @@ func (r RepositoryParticipation) String() string {
// The array order is oldest week (index 0) to most recent week.
//
// If this is the first time these statistics are requested for the given
// repository, this method will return a non-nil error and a status code
// of 202. This is because this is the status that github returns to
// signify that it is now computing the requested statistics. A follow
// up request, after a delay of a second or so, should result in a
// successful request.
// repository, this method will return an *AcceptedError and a status code of
// 202. This is because this is the status that GitHub returns to signify that
// it is now computing the requested statistics. A follow up request, after a
// delay of a second or so, should result in a successful request.
//
// GitHub API Docs: https://developer.github.com/v3/repos/statistics/#participation
func (s *RepositoriesService) ListParticipation(owner, repo string) (*RepositoryParticipation, *Response, error) {
@@ -185,6 +190,12 @@ type PunchCard struct {
// ListPunchCard returns the number of commits per hour in each day.
//
// If this is the first time these statistics are requested for the given
// repository, this method will return an *AcceptedError and a status code of
// 202. This is because this is the status that GitHub returns to signify that
// it is now computing the requested statistics. A follow up request, after a
// delay of a second or so, should result in a successful request.
//
// GitHub API Docs: https://developer.github.com/v3/repos/statistics/#punch-card
func (s *RepositoriesService) ListPunchCard(owner, repo string) ([]*PunchCard, *Response, error) {
u := fmt.Sprintf("repos/%v/%v/stats/punch_card", owner, repo)

View File

@@ -199,7 +199,7 @@ func (n *netConn) close() error {
// local machine using a Unix domain socket.
func unixSyslog() (conn serverConn, err error) {
logTypes := []string{"unixgram", "unix"}
logPaths := []string{"/dev/log", "/var/run/syslog"}
logPaths := []string{"/dev/log", "/var/run/syslog", "/var/run/log"}
for _, network := range logTypes {
for _, path := range logPaths {
conn, err := net.DialTimeout(network, path, localDeadline)

View File

@@ -28,6 +28,23 @@ func NewLogger(p Priority, facility, tag string) (Syslogger, error) {
return &builtinLogger{l}, nil
}
// DialLogger is used to construct a new Syslogger that establishes a connection to a remote syslog server
func DialLogger(network, raddr string, p Priority, facility, tag string) (Syslogger, error) {
fPriority, err := facilityPriority(facility)
if err != nil {
return nil, err
}
priority := syslog.Priority(p) | fPriority
l, err := dialBuiltin(network, raddr, priority, tag)
if err != nil {
return nil, err
}
return &builtinLogger{l}, nil
}
// WriteLevel writes out a message at the given priority
func (b *builtinLogger) WriteLevel(p Priority, buf []byte) error {
var err error

View File

@@ -10,3 +10,8 @@ import (
func NewLogger(p Priority, facility, tag string) (Syslogger, error) {
return nil, fmt.Errorf("Platform does not support syslog")
}
// DialLogger is used to construct a new Syslogger that establishes a connection to a remote syslog server
func DialLogger(network, raddr string, p Priority, facility, tag string) (Syslogger, error) {
return nil, fmt.Errorf("Platform does not support syslog")
}
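A usage sketch for the new constructor, assuming the package is imported as gsyslog and a syslog daemon is reachable at the given address:

	logger, err := gsyslog.DialLogger("udp", "127.0.0.1:514", gsyslog.LOG_INFO, "LOCAL0", "myapp")
	if err != nil {
		log.Fatal(err)
	}
	if err := logger.WriteLevel(gsyslog.LOG_WARNING, []byte("hello from a remote logger")); err != nil {
		log.Fatal(err)
	}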

View File

@@ -5,6 +5,7 @@
package openpgp
import (
"crypto/hmac"
"github.com/keybase/go-crypto/openpgp/armor"
"github.com/keybase/go-crypto/openpgp/errors"
"github.com/keybase/go-crypto/openpgp/packet"
@@ -67,13 +68,22 @@ type Key struct {
// A KeyRing provides access to public and private keys.
type KeyRing interface {
// KeysById returns the set of keys that have the given key id.
KeysById(id uint64) []Key
// fp can be optionally supplied, which is the full key fingerprint.
// If it's provided, then it must match. This comes up in the case
// of GPG subpacket 33.
KeysById(id uint64, fp []byte) []Key
// KeysByIdAndUsage returns the set of keys with the given id
// that also meet the key usage given by requiredUsage.
// The requiredUsage is expressed as the bitwise-OR of
// packet.KeyFlag* values.
KeysByIdUsage(id uint64, requiredUsage byte) []Key
// fp can be optionally supplied, which is the full key fingerprint.
// If it's provided, then it must match. This comes up in the case
// of GPG subpacket 33.
KeysByIdUsage(id uint64, fp []byte, requiredUsage byte) []Key
// DecryptionKeys returns all private keys that are valid for
// decryption.
DecryptionKeys() []Key
@@ -185,10 +195,23 @@ func (e *Entity) signingKey(now time.Time) (Key, bool) {
// An EntityList contains one or more Entities.
type EntityList []*Entity
func keyMatchesIdAndFingerprint(key *packet.PublicKey, id uint64, fp []byte) bool {
if key.KeyId != id {
return false
}
if fp == nil {
return true
}
return hmac.Equal(fp, key.Fingerprint[:])
}
// KeysById returns the set of keys that have the given key id.
func (el EntityList) KeysById(id uint64) (keys []Key) {
// fp can be optionally supplied, which is the full key fingerprint.
// If it's provided, then it must match. This comes up in the case
// of GPG subpacket 33.
func (el EntityList) KeysById(id uint64, fp []byte) (keys []Key) {
for _, e := range el {
if e.PrimaryKey.KeyId == id {
if keyMatchesIdAndFingerprint(e.PrimaryKey, id, fp) {
var selfSig *packet.Signature
for _, ident := range e.Identities {
if selfSig == nil {
@@ -202,7 +225,7 @@ func (el EntityList) KeysById(id uint64) (keys []Key) {
}
for _, subKey := range e.Subkeys {
if subKey.PublicKey.KeyId == id {
if keyMatchesIdAndFingerprint(subKey.PublicKey, id, fp) {
// If there's both a revocation and a sig, then take the
// revocation. Otherwise, we can proceed with the sig.
@@ -221,8 +244,11 @@ func (el EntityList) KeysById(id uint64) (keys []Key) {
// KeysByIdAndUsage returns the set of keys with the given id that also meet
// the key usage given by requiredUsage. The requiredUsage is expressed as
// the bitwise-OR of packet.KeyFlag* values.
func (el EntityList) KeysByIdUsage(id uint64, requiredUsage byte) (keys []Key) {
for _, key := range el.KeysById(id) {
// fp can be optionally supplied, which is the full key fingerprint.
// If it's provided, then it must match. This comes up in the case
// of GPG subpacket 33.
func (el EntityList) KeysByIdUsage(id uint64, fp []byte, requiredUsage byte) (keys []Key) {
for _, key := range el.KeysById(id, fp) {
if len(key.Entity.Revocations) > 0 {
continue
}
@@ -263,7 +289,7 @@ func (el EntityList) KeysByIdUsage(id uint64, requiredUsage byte) (keys []Key) {
// For a primary RSA key without any key flags, be as permissible
// as possible.
case key.PublicKey.PubKeyAlgo == packet.PubKeyAlgoRSA &&
key.Entity.PrimaryKey.KeyId == id:
keyMatchesIdAndFingerprint(key.Entity.PrimaryKey, id, fp):
usage = (packet.KeyFlagCertify | packet.KeyFlagSign |
packet.KeyFlagEncryptCommunications | packet.KeyFlagEncryptStorage)
}
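A sketch of the new lookup shape; el is assumed to be an openpgp.EntityList, and issuerKeyID / issuerFingerprint would normally come from a parsed signature packet:

	// Any key matching the 64-bit key ID:
	keys := el.KeysById(issuerKeyID, nil)

	// Pin the lookup to a full fingerprint (GPG subpacket 33) and require signing usage:
	keys = el.KeysByIdUsage(issuerKeyID, issuerFingerprint, packet.KeyFlagSign)
	_ = keys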

View File

@@ -1,10 +1,9 @@
package packet
import (
"math/big"
"crypto/ecdsa"
"errors"
"math/big"
)
type ecdhPrivateKey struct {

View File

@@ -65,6 +65,7 @@ type Signature struct {
PreferredKeyServer string
IssuerKeyId *uint64
IsPrimaryId *bool
IssuerFingerprint []byte
// FlagsValid is set if any flags were given. See RFC 4880, section
// 5.2.3.21 for details.
@@ -233,6 +234,7 @@ const (
reasonForRevocationSubpacket signatureSubpacketType = 29
featuresSubpacket signatureSubpacketType = 30
embeddedSignatureSubpacket signatureSubpacketType = 32
issuerFingerprint signatureSubpacketType = 33
)
// parseSignatureSubpacket parses a single subpacket. len(subpacket) is >= 1.
@@ -420,6 +422,10 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
}
case prefKeyServerSubpacket:
sig.PreferredKeyServer = string(subpacket[:])
case issuerFingerprint:
// The first byte is how many bytes the fingerprint is, but we'll just
// read until the end of the subpacket, so we'll ignore it.
sig.IssuerFingerprint = append([]byte{}, subpacket[1:]...)
default:
if isCritical {
err = errors.UnsupportedError("unknown critical signature subpacket type " + strconv.Itoa(int(packetType)))

View File

@@ -7,6 +7,7 @@ package openpgp // import "github.com/keybase/go-crypto/openpgp"
import (
"crypto"
"crypto/hmac"
_ "crypto/sha256"
"hash"
"io"
@@ -122,7 +123,7 @@ ParsePackets:
if p.KeyId == 0 {
keys = keyring.DecryptionKeys()
} else {
keys = keyring.KeysById(p.KeyId)
keys = keyring.KeysById(p.KeyId, nil)
}
for _, k := range keys {
pubKeys = append(pubKeys, keyEnvelopePair{k, p})
@@ -255,7 +256,7 @@ FindLiteralData:
md.IsSigned = true
md.SignedByKeyId = p.KeyId
keys := keyring.KeysByIdUsage(p.KeyId, packet.KeyFlagSign)
keys := keyring.KeysByIdUsage(p.KeyId, nil, packet.KeyFlagSign)
if len(keys) > 0 {
md.SignedBy = &keys[0]
}
@@ -336,7 +337,16 @@ func (scr *signatureCheckReader) Read(buf []byte) (n int, err error) {
var ok bool
if scr.md.Signature, ok = p.(*packet.Signature); ok {
scr.md.SignatureError = scr.md.SignedBy.PublicKey.VerifySignature(scr.h, scr.md.Signature)
var err error
if fingerprint := scr.md.Signature.IssuerFingerprint; fingerprint != nil {
if !hmac.Equal(fingerprint, scr.md.SignedBy.PublicKey.Fingerprint[:]) {
err = errors.StructuralError("bad key fingerprint")
}
}
if err == nil {
err = scr.md.SignedBy.PublicKey.VerifySignature(scr.h, scr.md.Signature)
}
scr.md.SignatureError = err
} else if scr.md.SignatureV3, ok = p.(*packet.SignatureV3); ok {
scr.md.SignatureError = scr.md.SignedBy.PublicKey.VerifySignatureV3(scr.h, scr.md.SignatureV3)
} else {
@@ -367,6 +377,7 @@ func CheckDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signe
func checkDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signer *Entity, issuer *uint64, err error) {
var issuerKeyId uint64
var issuerFingerprint []byte
var hashFunc crypto.Hash
var sigType packet.SignatureType
var keys []Key
@@ -390,6 +401,7 @@ func checkDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signe
issuerKeyId = *sig.IssuerKeyId
hashFunc = sig.Hash
sigType = sig.SigType
issuerFingerprint = sig.IssuerFingerprint
case *packet.SignatureV3:
issuerKeyId = sig.IssuerKeyId
hashFunc = sig.Hash
@@ -398,7 +410,7 @@ func checkDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signe
return nil, nil, errors.StructuralError("non signature packet found")
}
keys = keyring.KeysByIdUsage(issuerKeyId, packet.KeyFlagSign)
keys = keyring.KeysByIdUsage(issuerKeyId, issuerFingerprint, packet.KeyFlagSign)
if len(keys) > 0 {
break
}

View File

@@ -7,7 +7,7 @@ all: test
.PHONY: test
test: install-dependencies
go test -v
go test -race -v
cover: install-dependencies install-cover
go test -v -test.coverprofile="$(COVER_FILE).prof"

View File

@@ -87,8 +87,8 @@ func dirUnix() (string, error) {
cmd := exec.Command("getent", "passwd", strconv.Itoa(os.Getuid()))
cmd.Stdout = &stdout
if err := cmd.Run(); err != nil {
// If "getent" is missing, ignore it
if err == exec.ErrNotFound {
// If the error is ErrNotFound, we ignore it. Otherwise, return it.
if err != exec.ErrNotFound {
return "", err
}
} else {

View File

@@ -1,5 +1,5 @@
// The mapstructure package exposes functionality to convert an
// abitrary map[string]interface{} into a native Go structure.
// arbitrary map[string]interface{} into a native Go structure.
//
// The Go structure can be arbitrarily complex, containing slices,
// other structs, etc. and the decoder will properly decode nested
@@ -546,7 +546,12 @@ func (d *Decoder) decodePtr(name string, data interface{}, val reflect.Value) er
// into that. Then set the value of the pointer to this type.
valType := val.Type()
valElemType := valType.Elem()
realVal := reflect.New(valElemType)
realVal := val
if realVal.IsNil() || d.config.ZeroFields {
realVal = reflect.New(valElemType)
}
if err := d.decode(name, data, reflect.Indirect(realVal)); err != nil {
return err
}
@@ -562,20 +567,24 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value)
valElemType := valType.Elem()
sliceType := reflect.SliceOf(valElemType)
// Check input type
if dataValKind != reflect.Array && dataValKind != reflect.Slice {
// Accept empty map instead of array/slice in weakly typed mode
if d.config.WeaklyTypedInput && dataVal.Kind() == reflect.Map && dataVal.Len() == 0 {
val.Set(reflect.MakeSlice(sliceType, 0, 0))
return nil
} else {
return fmt.Errorf(
"'%s': source data must be an array or slice, got %s", name, dataValKind)
}
}
valSlice := val
if valSlice.IsNil() || d.config.ZeroFields {
// Make a new slice to hold our result, same size as the original data.
valSlice := reflect.MakeSlice(sliceType, dataVal.Len(), dataVal.Len())
// Check input type
if dataValKind != reflect.Array && dataValKind != reflect.Slice {
// Accept empty map instead of array/slice in weakly typed mode
if d.config.WeaklyTypedInput && dataVal.Kind() == reflect.Map && dataVal.Len() == 0 {
val.Set(reflect.MakeSlice(sliceType, 0, 0))
return nil
} else {
return fmt.Errorf(
"'%s': source data must be an array or slice, got %s", name, dataValKind)
}
}
// Make a new slice to hold our result, same size as the original data.
valSlice = reflect.MakeSlice(sliceType, dataVal.Len(), dataVal.Len())
}
// Accumulate any errors
errors := make([]string, 0)
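The reuse-versus-reallocate behavior above is driven by DecoderConfig.ZeroFields; a rough sketch (the struct and keys are illustrative):

	var out struct {
		Tags []string
	}
	out.Tags = []string{"pre-existing"}

	dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
		ZeroFields: true, // discard anything already in the target before decoding
		Result:     &out,
	})
	if err != nil {
		log.Fatal(err)
	}
	if err := dec.Decode(map[string]interface{}{"tags": []string{"a", "b"}}); err != nil {
		log.Fatal(err)
	}
	// With ZeroFields left false (the default), existing pointer and slice values
	// in the target are reused instead of being replaced with fresh allocations.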

View File

@@ -139,3 +139,4 @@ Contributors
- Sam Gunaratne <samgzeit@gmail.com>
- Richard Scothern <richard.scothern@gmail.com>
- Michel Couillard <couillard.michel@voxlog.ca>
- Christopher Waldon <ckwaldon@us.ibm.com>

vendor/github.com/ncw/swift/swift.go generated vendored
View File

@@ -1141,6 +1141,19 @@ func (file *ObjectCreateFile) Close() error {
return nil
}
// Headers returns the response headers from the created object if the upload
// has been completed. The Close() method must be called on an ObjectCreateFile
// before this method.
func (file *ObjectCreateFile) Headers() (Headers, error) {
// error out if upload is not complete.
select {
case <-file.done:
default:
return nil, fmt.Errorf("Cannot get metadata, object upload failed or has not yet completed.")
}
return file.headers, nil
}
// Check it satisfies the interface
var _ io.WriteCloser = &ObjectCreateFile{}
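A sketch of the intended call order; c is assumed to be an authenticated *swift.Connection and the ObjectCreate arguments are illustrative:

	f, err := c.ObjectCreate("container", "object.txt", false, "", "text/plain", nil)
	if err != nil {
		log.Fatal(err)
	}
	if _, err := f.Write([]byte("hello")); err != nil {
		log.Fatal(err)
	}
	if err := f.Close(); err != nil { // the upload must complete before Headers()
		log.Fatal(err)
	}
	headers, err := f.Headers()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(headers)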

View File

@@ -1,7 +1,7 @@
This project was automatically exported from code.google.com/p/go-uuid
# uuid ![build status](https://travis-ci.org/pborman/uuid.svg?branch=master)
The uuid package generates and inspects UUIDs based on [RFC 412](http://tools.ietf.org/html/rfc4122) and DCE 1.1: Authentication and Security Services.
The uuid package generates and inspects UUIDs based on [RFC 4122](http://tools.ietf.org/html/rfc4122) and DCE 1.1: Authentication and Security Services.
###### Install
`go get github.com/pborman/uuid`

View File

@@ -132,8 +132,11 @@ const (
keyPasteEnd
)
var pasteStart = []byte{keyEscape, '[', '2', '0', '0', '~'}
var pasteEnd = []byte{keyEscape, '[', '2', '0', '1', '~'}
var (
crlf = []byte{'\r', '\n'}
pasteStart = []byte{keyEscape, '[', '2', '0', '0', '~'}
pasteEnd = []byte{keyEscape, '[', '2', '0', '1', '~'}
)
// bytesToKey tries to parse a key sequence from b. If successful, it returns
// the key and the remainder of the input. Otherwise it returns utf8.RuneError.
@@ -333,7 +336,7 @@ func (t *Terminal) advanceCursor(places int) {
// So, if we are stopping at the end of a line, we
// need to write a newline so that our cursor can be
// advanced to the next line.
t.outBuf = append(t.outBuf, '\n')
t.outBuf = append(t.outBuf, '\r', '\n')
}
}
@@ -593,6 +596,35 @@ func (t *Terminal) writeLine(line []rune) {
}
}
// writeWithCRLF writes buf to w but replaces all occurrences of \n with \r\n.
func writeWithCRLF(w io.Writer, buf []byte) (n int, err error) {
for len(buf) > 0 {
i := bytes.IndexByte(buf, '\n')
todo := len(buf)
if i >= 0 {
todo = i
}
var nn int
nn, err = w.Write(buf[:todo])
n += nn
if err != nil {
return n, err
}
buf = buf[todo:]
if i >= 0 {
if _, err = w.Write(crlf); err != nil {
return n, err
}
n += 1
buf = buf[1:]
}
}
return n, nil
}
func (t *Terminal) Write(buf []byte) (n int, err error) {
t.lock.Lock()
defer t.lock.Unlock()
@@ -600,7 +632,7 @@ func (t *Terminal) Write(buf []byte) (n int, err error) {
if t.cursorX == 0 && t.cursorY == 0 {
// This is the easy case: there's nothing on the screen that we
// have to move out of the way.
return t.c.Write(buf)
return writeWithCRLF(t.c, buf)
}
// We have a prompt and possibly user input on the screen. We
@@ -620,7 +652,7 @@ func (t *Terminal) Write(buf []byte) (n int, err error) {
}
t.outBuf = t.outBuf[:0]
if n, err = t.c.Write(buf); err != nil {
if n, err = writeWithCRLF(t.c, buf); err != nil {
return
}
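The practical effect, sketched against an arbitrary io.ReadWriter carrying a raw-mode terminal (conn here stands in for something like an SSH channel):

	t := terminal.NewTerminal(conn, "> ")
	fmt.Fprint(t, "first line\nsecond line\n")
	// Each '\n' is now emitted as "\r\n", so the second line starts at column 0
	// instead of drifting rightward in raw mode.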

View File

@@ -317,10 +317,12 @@ type Framer struct {
// non-Continuation or Continuation on a different stream is
// attempted to be written.
logReads bool
logReads, logWrites bool
debugFramer *Framer // only use for logging written writes
debugFramerBuf *bytes.Buffer
debugFramer *Framer // only use for logging written writes
debugFramerBuf *bytes.Buffer
debugReadLoggerf func(string, ...interface{})
debugWriteLoggerf func(string, ...interface{})
}
func (fr *Framer) maxHeaderListSize() uint32 {
@@ -355,7 +357,7 @@ func (f *Framer) endWrite() error {
byte(length>>16),
byte(length>>8),
byte(length))
if logFrameWrites {
if f.logWrites {
f.logWrite()
}
@@ -378,10 +380,10 @@ func (f *Framer) logWrite() {
f.debugFramerBuf.Write(f.wbuf)
fr, err := f.debugFramer.ReadFrame()
if err != nil {
log.Printf("http2: Framer %p: failed to decode just-written frame", f)
f.debugWriteLoggerf("http2: Framer %p: failed to decode just-written frame", f)
return
}
log.Printf("http2: Framer %p: wrote %v", f, summarizeFrame(fr))
f.debugWriteLoggerf("http2: Framer %p: wrote %v", f, summarizeFrame(fr))
}
func (f *Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) }
@@ -399,9 +401,12 @@ const (
// NewFramer returns a Framer that writes frames to w and reads them from r.
func NewFramer(w io.Writer, r io.Reader) *Framer {
fr := &Framer{
w: w,
r: r,
logReads: logFrameReads,
w: w,
r: r,
logReads: logFrameReads,
logWrites: logFrameWrites,
debugReadLoggerf: log.Printf,
debugWriteLoggerf: log.Printf,
}
fr.getReadBuf = func(size uint32) []byte {
if cap(fr.readBuf) >= int(size) {
@@ -483,7 +488,7 @@ func (fr *Framer) ReadFrame() (Frame, error) {
return nil, err
}
if fr.logReads {
log.Printf("http2: Framer %p: read %v", fr, summarizeFrame(f))
fr.debugReadLoggerf("http2: Framer %p: read %v", fr, summarizeFrame(f))
}
if fh.Type == FrameHeaders && fr.ReadMetaHeaders != nil {
return fr.readMetaFrame(f.(*HeadersFrame))
@@ -1419,8 +1424,8 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
hdec.SetEmitEnabled(true)
hdec.SetMaxStringLength(fr.maxHeaderStringLen())
hdec.SetEmitFunc(func(hf hpack.HeaderField) {
if VerboseLogs && logFrameReads {
log.Printf("http2: decoded hpack field %+v", hf)
if VerboseLogs && fr.logReads {
fr.debugReadLoggerf("http2: decoded hpack field %+v", hf)
}
if !httplex.ValidHeaderFieldValue(hf.Value) {
invalid = headerFieldValueError(hf.Value)

View File

@@ -8,6 +8,7 @@ package http2
import (
"crypto/tls"
"io"
"net/http"
)
@@ -39,3 +40,11 @@ func configureServer18(h1 *http.Server, h2 *Server) error {
func shouldLogPanic(panicValue interface{}) bool {
return panicValue != nil && panicValue != http.ErrAbortHandler
}
func reqGetBody(req *http.Request) func() (io.ReadCloser, error) {
return req.GetBody
}
func reqBodyIsNoBody(body io.ReadCloser) bool {
return body == http.NoBody
}

View File

@@ -6,7 +6,10 @@
package http2
import "net/http"
import (
"io"
"net/http"
)
func configureServer18(h1 *http.Server, h2 *Server) error {
// No IdleTimeout to sync prior to Go 1.8.
@@ -16,3 +19,9 @@ func configureServer18(h1 *http.Server, h2 *Server) error {
func shouldLogPanic(panicValue interface{}) bool {
return panicValue != nil
}
func reqGetBody(req *http.Request) func() (io.ReadCloser, error) {
return nil
}
func reqBodyIsNoBody(io.ReadCloser) bool { return false }

View File

@@ -423,6 +423,11 @@ func (sc *serverConn) maxHeaderListSize() uint32 {
return uint32(n + typicalHeaders*perFieldOverhead)
}
func (sc *serverConn) curOpenStreams() uint32 {
sc.serveG.check()
return sc.curClientStreams + sc.curPushedStreams
}
// stream represents a stream. This is the minimal metadata needed by
// the serve goroutine. Most of the actual stream state is owned by
// the http.Handler's goroutine in the responseWriter. Because the
@@ -448,8 +453,7 @@ type stream struct {
numTrailerValues int64
weight uint8
state streamState
sentReset bool // only true once detached from streams map
gotReset bool // only true once detacted from streams map
resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
gotTrailerHeader bool // HEADER frame for trailers was seen
wroteHeaders bool // whether we wrote headers (not status 100)
reqBuf []byte // if non-nil, body pipe buffer to return later at EOF
@@ -753,7 +757,7 @@ func (sc *serverConn) serve() {
fn(loopNum)
}
if sc.inGoAway && sc.curClientStreams == 0 && !sc.needToSendGoAway && !sc.writingFrame {
if sc.inGoAway && sc.curOpenStreams() == 0 && !sc.needToSendGoAway && !sc.writingFrame {
return
}
}
@@ -869,8 +873,34 @@ func (sc *serverConn) writeFrameFromHandler(wr FrameWriteRequest) error {
func (sc *serverConn) writeFrame(wr FrameWriteRequest) {
sc.serveG.check()
// If true, wr will not be written and wr.done will not be signaled.
var ignoreWrite bool
// We are not allowed to write frames on closed streams. RFC 7540 Section
// 5.1.1 says: "An endpoint MUST NOT send frames other than PRIORITY on
// a closed stream." Our server never sends PRIORITY, so that exception
// does not apply.
//
// The serverConn might close an open stream while the stream's handler
// is still running. For example, the server might close a stream when it
// receives bad data from the client. If this happens, the handler might
// attempt to write a frame after the stream has been closed (since the
// handler hasn't yet been notified of the close). In this case, we simply
// ignore the frame. The handler will notice that the stream is closed when
// it waits for the frame to be written.
//
// As an exception to this rule, we allow sending RST_STREAM after close.
// This allows us to immediately reject new streams without tracking any
// state for those streams (except for the queued RST_STREAM frame). This
// may result in duplicate RST_STREAMs in some cases, but the client should
// ignore those.
if wr.StreamID() != 0 {
_, isReset := wr.write.(StreamError)
if state, _ := sc.state(wr.StreamID()); state == stateClosed && !isReset {
ignoreWrite = true
}
}
// Don't send a 100-continue response if we've already sent headers.
// See golang.org/issue/14030.
switch wr.write.(type) {
@@ -878,6 +908,11 @@ func (sc *serverConn) writeFrame(wr FrameWriteRequest) {
wr.stream.wroteHeaders = true
case write100ContinueHeadersFrame:
if wr.stream.wroteHeaders {
// We do not need to notify wr.done because this frame is
// never written with wr.done != nil.
if wr.done != nil {
panic("wr.done != nil for write100ContinueHeadersFrame")
}
ignoreWrite = true
}
}
@@ -903,11 +938,6 @@ func (sc *serverConn) startFrameWrite(wr FrameWriteRequest) {
case stateHalfClosedLocal:
panic("internal error: attempt to send frame on half-closed-local stream")
case stateClosed:
if st.sentReset || st.gotReset {
// Skip this frame.
sc.scheduleFrameWrite()
return
}
panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wr))
}
}
@@ -916,9 +946,7 @@ func (sc *serverConn) startFrameWrite(wr FrameWriteRequest) {
wpp.promisedID, err = wpp.allocatePromisedID()
if err != nil {
sc.writingFrameAsync = false
if wr.done != nil {
wr.done <- err
}
wr.replyToWriter(err)
return
}
}
@@ -951,25 +979,9 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) {
sc.writingFrameAsync = false
wr := res.wr
st := wr.stream
closeStream := endsStream(wr.write)
if _, ok := wr.write.(handlerPanicRST); ok {
sc.closeStream(st, errHandlerPanicked)
}
// Reply (if requested) to the blocked ServeHTTP goroutine.
if ch := wr.done; ch != nil {
select {
case ch <- res.err:
default:
panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wr.write))
}
}
wr.write = nil // prevent use (assume it's tainted after wr.done send)
if closeStream {
if writeEndsStream(wr.write) {
st := wr.stream
if st == nil {
panic("internal error: expecting non-nil stream")
}
@@ -982,15 +994,29 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) {
// reading data (see possible TODO at top of
// this file), we go into closed state here
// anyway, after telling the peer we're
// hanging up on them.
st.state = stateHalfClosedLocal // won't last long, but necessary for closeStream via resetStream
errCancel := streamError(st.id, ErrCodeCancel)
sc.resetStream(errCancel)
// hanging up on them. We'll transition to
// stateClosed after the RST_STREAM frame is
// written.
st.state = stateHalfClosedLocal
sc.resetStream(streamError(st.id, ErrCodeCancel))
case stateHalfClosedRemote:
sc.closeStream(st, errHandlerComplete)
}
} else {
switch v := wr.write.(type) {
case StreamError:
// st may be unknown if the RST_STREAM was generated to reject bad input.
if st, ok := sc.streams[v.StreamID]; ok {
sc.closeStream(st, v)
}
case handlerPanicRST:
sc.closeStream(wr.stream, errHandlerPanicked)
}
}
// Reply (if requested) to unblock the ServeHTTP goroutine.
wr.replyToWriter(res.err)
sc.scheduleFrameWrite()
}
@@ -1087,8 +1113,7 @@ func (sc *serverConn) resetStream(se StreamError) {
sc.serveG.check()
sc.writeFrame(FrameWriteRequest{write: se})
if st, ok := sc.streams[se.StreamID]; ok {
st.sentReset = true
sc.closeStream(st, se)
st.resetQueued = true
}
}
@@ -1252,7 +1277,6 @@ func (sc *serverConn) processResetStream(f *RSTStreamFrame) error {
return ConnectionError(ErrCodeProtocol)
}
if st != nil {
st.gotReset = true
st.cancelCtx()
sc.closeStream(st, streamError(f.StreamID, f.ErrCode))
}
@@ -1391,7 +1415,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
// type PROTOCOL_ERROR."
return ConnectionError(ErrCodeProtocol)
}
if st == nil || state != stateOpen || st.gotTrailerHeader {
if st == nil || state != stateOpen || st.gotTrailerHeader || st.resetQueued {
// This includes sending a RST_STREAM if the stream is
// in stateHalfClosedLocal (which currently means that
// the http.Handler returned, so it's done reading &
@@ -1411,6 +1435,10 @@ func (sc *serverConn) processData(f *DataFrame) error {
sc.inflow.take(int32(f.Length))
sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
if st != nil && st.resetQueued {
// Already have a stream error in flight. Don't send another.
return nil
}
return streamError(id, ErrCodeStreamClosed)
}
if st.body == nil {
@@ -1519,6 +1547,11 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// open, let it process its own HEADERS frame (trailers at this
// point, if it's valid).
if st := sc.streams[f.StreamID]; st != nil {
if st.resetQueued {
// We're sending RST_STREAM to close the stream, so don't bother
// processing this frame.
return nil
}
return st.processTrailerHeaders(f)
}
@@ -1681,7 +1714,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
} else {
sc.curClientStreams++
}
if sc.curClientStreams+sc.curPushedStreams == 1 {
if sc.curOpenStreams() == 1 {
sc.setConnState(http.StateActive)
}

View File

@@ -191,6 +191,7 @@ type clientStream struct {
ID uint32
resc chan resAndError
bufPipe pipe // buffered pipe with the flow-controlled response payload
startedWrite bool // started request body write; guarded by cc.mu
requestedGzip bool
on100 func() // optional code to run if get a 100 continue response
@@ -314,6 +315,10 @@ func authorityAddr(scheme string, authority string) (addr string) {
if a, err := idna.ToASCII(host); err == nil {
host = a
}
// IPv6 address literal, without a port:
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
return host + ":" + port
}
return net.JoinHostPort(host, port)
}
@@ -332,8 +337,10 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
}
traceGotConn(req, cc)
res, err := cc.RoundTrip(req)
if shouldRetryRequest(req, err) {
continue
if err != nil {
if req, err = shouldRetryRequest(req, err); err == nil {
continue
}
}
if err != nil {
t.vlogf("RoundTrip failure: %v", err)
@@ -355,12 +362,41 @@ func (t *Transport) CloseIdleConnections() {
var (
errClientConnClosed = errors.New("http2: client conn is closed")
errClientConnUnusable = errors.New("http2: client conn not usable")
errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
errClientConnGotGoAwayAfterSomeReqBody = errors.New("http2: Transport received Server's graceful shutdown GOAWAY; some request body already written")
)
func shouldRetryRequest(req *http.Request, err error) bool {
// TODO: retry GET requests (no bodies) more aggressively, if shutdown
// before response.
return err == errClientConnUnusable
// shouldRetryRequest is called by RoundTrip when a request fails to get
// response headers. It is always called with a non-nil error.
// It returns either a request to retry (either the same request, or a
// modified clone), or an error if the request can't be replayed.
func shouldRetryRequest(req *http.Request, err error) (*http.Request, error) {
switch err {
default:
return nil, err
case errClientConnUnusable, errClientConnGotGoAway:
return req, nil
case errClientConnGotGoAwayAfterSomeReqBody:
// If the Body is nil (or http.NoBody), it's safe to reuse
// this request and its Body.
if req.Body == nil || reqBodyIsNoBody(req.Body) {
return req, nil
}
// Otherwise we depend on the Request having its GetBody
// func defined.
getBody := reqGetBody(req) // Go 1.8: getBody = req.GetBody
if getBody == nil {
return nil, errors.New("http2: Transport: peer server initiated graceful shutdown after some of Request.Body was written; define Request.GetBody to avoid this error")
}
body, err := getBody()
if err != nil {
return nil, err
}
newReq := *req
newReq.Body = body
return &newReq, nil
}
}
func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, error) {
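On the caller's side this retry path only works for bodies that can be replayed. Go 1.8's http.NewRequest populates Request.GetBody automatically for *bytes.Buffer, *bytes.Reader and *strings.Reader bodies; for anything else it has to be supplied by hand, roughly:

	payload := `{"hello":"world"}`
	req, err := http.NewRequest("POST", "https://example.com/v1/thing", strings.NewReader(payload))
	if err != nil {
		log.Fatal(err)
	}
	// For a custom body type, provide a replay function yourself (Go 1.8+):
	req.GetBody = func() (io.ReadCloser, error) {
		return ioutil.NopCloser(strings.NewReader(payload)), nil
	}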
@@ -513,6 +549,15 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
if old != nil && old.ErrCode != ErrCodeNo {
cc.goAway.ErrCode = old.ErrCode
}
last := f.LastStreamID
for streamID, cs := range cc.streams {
if streamID > last {
select {
case cs.resc <- resAndError{err: errClientConnGotGoAway}:
default:
}
}
}
}
func (cc *ClientConn) CanTakeNewRequest() bool {
@@ -773,6 +818,13 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
cs.abortRequestBodyWrite(errStopReqBodyWrite)
}
if re.err != nil {
if re.err == errClientConnGotGoAway {
cc.mu.Lock()
if cs.startedWrite {
re.err = errClientConnGotGoAwayAfterSomeReqBody
}
cc.mu.Unlock()
}
cc.forgetStreamID(cs.ID)
return nil, re.err
}
@@ -2013,6 +2065,9 @@ func (t *Transport) getBodyWriterState(cs *clientStream, body io.Reader) (s body
resc := make(chan error, 1)
s.resc = resc
s.fn = func() {
cs.cc.mu.Lock()
cs.startedWrite = true
cs.cc.mu.Unlock()
resc <- cs.writeRequestBody(body, cs.req.Body)
}
s.delay = t.expectContinueTimeout()

View File

@@ -45,9 +45,10 @@ type writeContext interface {
HeaderEncoder() (*hpack.Encoder, *bytes.Buffer)
}
// endsStream reports whether the given frame writer w will locally
// close the stream.
func endsStream(w writeFramer) bool {
// writeEndsStream reports whether w writes a frame that will transition
// the stream to a half-closed local state. This returns false for RST_STREAM,
// which closes the entire stream (not just the local half).
func writeEndsStream(w writeFramer) bool {
switch v := w.(type) {
case *writeData:
return v.endStream
@@ -57,7 +58,7 @@ func endsStream(w writeFramer) bool {
// This can only happen if the caller reuses w after it's
// been intentionally nil'ed out to prevent use. Keep this
// here to catch future refactoring breaking it.
panic("endsStream called on nil writeFramer")
panic("writeEndsStream called on nil writeFramer")
}
return false
}

Some files were not shown because too many files have changed in this diff.